support variables in modify request,response middlewares

This commit is contained in:
yusing 2024-12-03 10:20:18 +08:00
parent cd749ac6a4
commit ef1863f810
15 changed files with 291 additions and 114 deletions

View file

@ -12,6 +12,10 @@ const (
HeaderXForwardedPort = "X-Forwarded-Port" HeaderXForwardedPort = "X-Forwarded-Port"
HeaderXForwardedURI = "X-Forwarded-Uri" HeaderXForwardedURI = "X-Forwarded-Uri"
HeaderXRealIP = "X-Real-IP" HeaderXRealIP = "X-Real-IP"
HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme"
HeaderUpstreamHost = "X-GoDoxy-Upstream-Host"
HeaderUpstreamPort = "X-GoDoxy-Upstream-Port"
) )
func RemoveHop(h http.Header) { func RemoveHop(h http.Header) {

View file

@ -164,7 +164,15 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
} }
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rp.TargetName, middlewares) mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
name: "set_upstream_headers",
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
next(w, r)
},
}}, middlewares...))
if mid.before != nil { if mid.before != nil {
ori := rp.HandlerFunc ori := rp.HandlerFunc

View file

@ -19,12 +19,7 @@ func BuildMiddlewaresFromComposeFile(filePath string, eb *E.Builder) map[string]
eb.Add(err) eb.Add(err)
return nil return nil
} }
mids := BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb)
results := make(map[string]*Middleware, len(mids))
for k, v := range mids {
results[k+"@file"] = v
}
return results
} }
func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware { func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[string]*Middleware {
@ -40,7 +35,7 @@ func BuildMiddlewaresFromYAML(source string, data []byte, eb *E.Builder) map[str
if err != nil { if err != nil {
eb.Add(err.Subject(source)) eb.Add(err.Subject(source))
} else { } else {
middlewares[name] = chain middlewares[name+"@file"] = chain
} }
} }
return middlewares return middlewares

View file

@ -38,17 +38,23 @@ func init() {
// snakes and cases will be stripped on `Get` // snakes and cases will be stripped on `Get`
// so keys are lowercase without snake. // so keys are lowercase without snake.
allMiddlewares = map[string]*Middleware{ allMiddlewares = map[string]*Middleware{
"setxforwarded": SetXForwarded, "redirecthttp": RedirectHTTP,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP, "request": ModifyRequest,
"modifyresponse": ModifyResponse, "modifyrequest": ModifyRequest,
"modifyrequest": ModifyRequest, "response": ModifyResponse,
"errorpage": CustomErrorPage, "modifyresponse": ModifyResponse,
"customerrorpage": CustomErrorPage, "setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP, "realip": RealIP,
"cloudflarerealip": CloudflareRealIP, "cloudflarerealip": CloudflareRealIP,
"cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter, "cidrwhitelist": CIDRWhiteList,
"ratelimit": RateLimiter,
// !experimental // !experimental
"forwardauth": ForwardAuth, "forwardauth": ForwardAuth,

View file

@ -2,15 +2,16 @@ package middleware
import ( import (
"net/http" "net/http"
"strings"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
type ( type (
modifyRequest struct { modifyRequest struct {
modifyRequestOpts modifyRequestOpts
m *Middleware m *Middleware
needVarSubstitution bool
} }
// order: set_headers -> add_headers -> hide_headers // order: set_headers -> add_headers -> hide_headers
modifyRequestOpts struct { modifyRequestOpts struct {
@ -24,38 +25,49 @@ var ModifyRequest = &Middleware{withOptions: NewModifyRequest}
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) { func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyRequest) mr := new(modifyRequest)
mrFunc := mr.modifyRequest
if common.IsDebug {
mrFunc = mr.modifyRequestWithTrace
}
mr.m = &Middleware{ mr.m = &Middleware{
impl: mr, impl: mr,
before: Rewrite(mrFunc), before: Rewrite(func(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyHeaders(req, nil, req.Header)
mr.m.AddTraceRequest("after modify request", req)
}),
} }
err := Deserialize(optsRaw, &mr.modifyRequestOpts) err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mr.checkVarSubstitution()
return mr.m, nil return mr.m, nil
} }
func (mr *modifyRequest) modifyRequest(req *Request) { func (mr *modifyRequest) checkVarSubstitution() {
for k, v := range mr.SetHeaders { for _, m := range []map[string]string{mr.SetHeaders, mr.AddHeaders} {
if http.CanonicalHeaderKey(k) == "Host" { for _, v := range m {
req.Host = v if strings.Contains(v, "$") {
mr.needVarSubstitution = true
return
}
} }
req.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
req.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
req.Header.Del(k)
} }
} }
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) { func (mr *modifyRequest) modifyHeaders(req *Request, resp *Response, headers http.Header) {
mr.m.AddTraceRequest("before modify request", req) replaceVars := varReplacerDummy
mr.modifyRequest(req) if mr.needVarSubstitution {
mr.m.AddTraceRequest("after modify request", req) replaceVars = varReplacer(req, resp)
}
for k, v := range mr.SetHeaders {
if strings.ToLower(k) == "host" {
req.Host = replaceVars(v)
}
headers.Set(k, replaceVars(v))
}
for k, v := range mr.AddHeaders {
headers.Add(k, replaceVars(v))
}
for _, k := range mr.HideHeaders {
headers.Del(k)
}
} }

View file

@ -1,17 +1,39 @@
package middleware package middleware
import ( import (
"bytes"
"net/http"
"slices" "slices"
"testing" "testing"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestSetModifyRequest(t *testing.T) { func TestSetModifyRequest(t *testing.T) {
opts := OptionsRaw{ opts := OptionsRaw{
"set_headers": map[string]string{ "set_headers": map[string]string{
"User-Agent": "go-proxy/v0.5.0", "User-Agent": "go-proxy/v0.5.0",
"Host": "test.example.com", "Host": "$upstream_addr",
"X-Test-Req-Method": "$req_method",
"X-Test-Req-Scheme": "$req_scheme",
"X-Test-Req-Host": "$req_host",
"X-Test-Req-Port": "$req_port",
"X-Test-Req-Addr": "$req_addr",
"X-Test-Req-Path": "$req_path",
"X-Test-Req-Query": "$req_query",
"X-Test-Req-Url": "$req_url",
"X-Test-Req-Uri": "$req_uri",
"X-Test-Req-Content-Type": "$req_content_type",
"X-Test-Req-Content-Length": "$req_content_length",
"X-Test-Remote-Addr": "$remote_addr",
"X-Test-Upstream-Scheme": "$upstream_scheme",
"X-Test-Upstream-Host": "$upstream_host",
"X-Test-Upstream-Port": "$upstream_port",
"X-Test-Upstream-Addr": "$upstream_addr",
"X-Test-Upstream-Url": "$upstream_url",
"X-Test-Content-Type": "$header(Content-Type)",
"X-Test-Arg-Arg_1": "$arg(arg_1)",
}, },
"add_headers": map[string]string{"Accept-Encoding": "test-value"}, "add_headers": map[string]string{"Accept-Encoding": "test-value"},
"hide_headers": []string{"Accept"}, "hide_headers": []string{"Accept"},
@ -26,13 +48,44 @@ func TestSetModifyRequest(t *testing.T) {
}) })
t.Run("request_headers", func(t *testing.T) { t.Run("request_headers", func(t *testing.T) {
reqURL := types.MustParseURL("https://my.app/?arg_1=b")
upstreamURL := types.MustParseURL("http://test.example.com")
result, err := newMiddlewareTest(ModifyRequest, &testArgs{ result, err := newMiddlewareTest(ModifyRequest, &testArgs{
middlewareOpt: opts, middlewareOpt: opts,
reqURL: reqURL,
upstreamURL: upstreamURL,
body: bytes.Repeat([]byte("a"), 100),
headers: http.Header{
"Content-Type": []string{"application/json"},
},
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com") ExpectEqual(t, result.RequestHeaders.Get("Host"), "test.example.com")
ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Method"), "GET")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Scheme"), reqURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Host"), reqURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Port"), reqURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Addr"), reqURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Path"), reqURL.Path)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Query"), reqURL.RawQuery)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Url"), reqURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Uri"), reqURL.RequestURI())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Req-Content-Length"), "100")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Remote-Addr"), result.RemoteAddr)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Scheme"), upstreamURL.Scheme)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Host"), upstreamURL.Hostname())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Port"), upstreamURL.Port())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Addr"), upstreamURL.Host)
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Upstream-Url"), upstreamURL.String())
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Content-Type"), "application/json")
ExpectEqual(t, result.RequestHeaders.Get("X-Test-Arg-Arg_1"), "b")
}) })
} }

View file

@ -3,52 +3,28 @@ package middleware
import ( import (
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
) )
type ( type modifyResponse = modifyRequest
modifyResponse struct {
modifyResponseOpts
m *Middleware
}
// order: set_headers -> add_headers -> hide_headers
modifyResponseOpts = modifyRequestOpts
)
var ModifyResponse = &Middleware{withOptions: NewModifyResponse} var ModifyResponse = &Middleware{withOptions: NewModifyResponse}
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) { func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.Error) {
mr := new(modifyResponse) mr := new(modifyResponse)
mr.m = &Middleware{impl: mr} mr.m = &Middleware{
if common.IsDebug { impl: mr,
mr.m.modifyResponse = mr.modifyResponseWithTrace modifyResponse: func(resp *http.Response) error {
} else { mr.m.AddTraceResponse("before modify response", resp)
mr.m.modifyResponse = mr.modifyResponse mr.modifyHeaders(resp.Request, resp, resp.Header)
mr.m.AddTraceResponse("after modify response", resp)
return nil
},
} }
err := Deserialize(optsRaw, &mr.modifyResponseOpts) err := Deserialize(optsRaw, &mr.modifyRequestOpts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mr.checkVarSubstitution()
return mr.m, nil return mr.m, nil
} }
func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
for k, v := range mr.SetHeaders {
resp.Header.Set(k, v)
}
for k, v := range mr.AddHeaders {
resp.Header.Add(k, v)
}
for _, k := range mr.HideHeaders {
resp.Header.Del(k)
}
return nil
}
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
mr.m.AddTraceResponse("before modify response", resp)
err := mr.modifyResponse(resp)
mr.m.AddTraceResponse("after modify response", resp)
return err
}

View file

@ -5,21 +5,22 @@ import (
"testing" "testing"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/types"
. "github.com/yusing/go-proxy/internal/utils/testing" . "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestRedirectToHTTPs(t *testing.T) { func TestRedirectToHTTPs(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "http", reqURL: types.MustParseURL("http://example.com"),
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect) ExpectEqual(t, result.ResponseStatus, http.StatusTemporaryRedirect)
ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://"+testHost+":"+common.ProxyHTTPSPort) ExpectEqual(t, result.ResponseHeaders.Get("Location"), "https://example.com:"+common.ProxyHTTPSPort)
} }
func TestNoRedirect(t *testing.T) { func TestNoRedirect(t *testing.T) {
result, err := newMiddlewareTest(RedirectHTTP, &testArgs{ result, err := newMiddlewareTest(RedirectHTTP, &testArgs{
scheme: "https", reqURL: types.MustParseURL("https://example.com"),
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, result.ResponseStatus, http.StatusOK) ExpectEqual(t, result.ResponseStatus, http.StatusOK)

View file

@ -7,7 +7,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -19,8 +18,6 @@ import (
var testHeadersRaw []byte var testHeadersRaw []byte
var testHeaders http.Header var testHeaders http.Header
const testHost = "example.com"
func init() { func init() {
if !common.IsTest { if !common.IsTest {
return return
@ -67,16 +64,16 @@ type TestResult struct {
type testArgs struct { type testArgs struct {
middlewareOpt OptionsRaw middlewareOpt OptionsRaw
proxyURL string reqURL types.URL
upstreamURL types.URL
body []byte body []byte
scheme string realRoundTrip bool
headers http.Header
} }
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) { func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.Error) {
var body io.Reader var body io.Reader
var rr requestRecorder var rr requestRecorder
var proxyURL *url.URL
var requestTarget string
var err error var err error
if args == nil { if args == nil {
@ -87,34 +84,24 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
body = bytes.NewReader(args.body) body = bytes.NewReader(args.body)
} }
switch args.scheme { if args.reqURL.Nil() {
case "": args.reqURL = E.Must(types.ParseURL("https://example.com"))
fallthrough
case "http":
requestTarget = "http://" + testHost
case "https":
requestTarget = "https://" + testHost
default:
panic("typo?")
} }
req := httptest.NewRequest(http.MethodGet, requestTarget, body) req := httptest.NewRequest(http.MethodGet, args.reqURL.String(), body)
for k, v := range args.headers {
req.Header[k] = v
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
if args.scheme == "https" && req.TLS == nil { if args.upstreamURL.Nil() {
panic("bug occurred") args.upstreamURL = E.Must(types.ParseURL("https://10.0.0.1:8443")) // dummy url, no actual effect
} }
if args.proxyURL != "" { if args.realRoundTrip {
proxyURL, err = url.Parse(args.proxyURL)
if err != nil {
return nil, E.From(err)
}
rr.parent = http.DefaultTransport rr.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
} }
rp := gphttp.NewReverseProxy(middleware.name, types.NewURL(proxyURL), &rr) rp := gphttp.NewReverseProxy(middleware.name, args.upstreamURL, &rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil { if setOptErr != nil {
return nil, setOptErr return nil, setOptErr

View file

@ -97,10 +97,16 @@ func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
} }
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace { func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithRequest(req) return m.AddTracef("%s", msg).WithRequest(req)
} }
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace { func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
if !m.trace {
return nil
}
return m.AddTracef("%s", msg).WithResponse(resp) return m.AddTracef("%s", msg).WithResponse(resp)
} }

View file

@ -0,0 +1,106 @@
package middleware
import (
"net"
"net/http"
"regexp"
"strconv"
"strings"
gphttp "github.com/yusing/go-proxy/internal/net/http"
)
type varReplaceFunc func(string) string
var (
reArg = regexp.MustCompile(`\$arg\([\w-_]+\)`)
reHeader = regexp.MustCompile(`\$header\([\w-]+\)`)
reStatic = regexp.MustCompile(`\$[\w_]+`)
)
func varSubsMap(req *Request, resp *Response) map[string]func() string {
reqHost, reqPort, err := net.SplitHostPort(req.Host)
if err != nil {
reqHost = req.Host
}
reqAddr := reqHost
if reqPort != "" {
reqAddr += ":" + reqPort
}
pairs := map[string]func() string{
"$req_method": func() string { return req.Method },
"$req_scheme": func() string { return req.URL.Scheme },
"$req_host": func() string { return reqHost },
"$req_port": func() string { return reqPort },
"$req_addr": func() string { return reqAddr },
"$req_path": func() string { return req.URL.Path },
"$req_query": func() string { return req.URL.RawQuery },
"$req_url": func() string { return req.URL.String() },
"$req_uri": req.URL.RequestURI,
"$req_content_type": func() string { return req.Header.Get("Content-Type") },
"$req_content_length": func() string { return strconv.FormatInt(req.ContentLength, 10) },
"$remote_addr": func() string { return req.RemoteAddr },
}
if resp != nil {
pairs["$resp_content_type"] = func() string { return resp.Header.Get("Content-Type") }
pairs["$resp_content_length"] = func() string { return resp.Header.Get("Content-Length") }
pairs["$status_code"] = func() string { return strconv.Itoa(resp.StatusCode) }
}
upScheme := req.Header.Get(gphttp.HeaderUpstreamScheme)
if upScheme == "" {
return pairs
}
upHost := req.Header.Get(gphttp.HeaderUpstreamHost)
upPort := req.Header.Get(gphttp.HeaderUpstreamPort)
upAddr := upHost
if upPort != "" {
upAddr += ":" + upPort
}
upURL := upScheme + "://" + upAddr
pairs["$upstream_scheme"] = func() string { return upScheme }
pairs["$upstream_host"] = func() string { return upHost }
pairs["$upstream_port"] = func() string { return upPort }
pairs["$upstream_addr"] = func() string { return upAddr }
pairs["$upstream_url"] = func() string { return upURL }
return pairs
}
func varReplacer(req *Request, resp *Response) varReplaceFunc {
pairs := varSubsMap(req, resp)
return func(s string) string {
// Replace query parameters
s = reArg.ReplaceAllStringFunc(s, func(match string) string {
name := match[5 : len(match)-1]
for k, v := range req.URL.Query() {
if strings.EqualFold(k, name) {
return v[0]
}
}
return ""
})
// Replace headers
s = reHeader.ReplaceAllStringFunc(s, func(match string) string {
header := http.CanonicalHeaderKey(match[8 : len(match)-1])
return req.Header.Get(header)
})
// Replace static variables
return reStatic.ReplaceAllStringFunc(s, func(match string) string {
if fn, ok := pairs[match]; ok {
return fn()
}
return match
})
}
}
func varReplacerDummy(s string) string {
return s
}

View file

@ -264,6 +264,9 @@ func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response
} }
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// req.Header.Set(HeaderUpstreamScheme, p.TargetURL.Scheme)
// req.Header.Set(HeaderUpstreamHost, p.TargetURL.Hostname())
// req.Header.Set(HeaderUpstreamPort, p.TargetURL.Port())
p.HandlerFunc(rw, req) p.HandlerFunc(rw, req)
} }

View file

@ -8,6 +8,14 @@ type URL struct {
*urlPkg.URL *urlPkg.URL
} }
func MustParseURL(url string) URL {
u, err := ParseURL(url)
if err != nil {
panic(err)
}
return u
}
func ParseURL(url string) (URL, error) { func ParseURL(url string) (URL, error) {
u, err := urlPkg.Parse(url) u, err := urlPkg.Parse(url)
if err != nil { if err != nil {
@ -20,6 +28,10 @@ func NewURL(url *urlPkg.URL) URL {
return URL{url} return URL{url}
} }
func (u URL) Nil() bool {
return u.URL == nil
}
func (u URL) String() string { func (u URL) String() string {
if u.URL == nil { if u.URL == nil {
return "nil" return "nil"

View file

@ -101,12 +101,21 @@ func TestWebhookBody(t *testing.T) {
}) })
ExpectNoError(t, err) ExpectNoError(t, err)
var body map[string][]map[string]any var body struct {
Embeds []struct {
Title string `json:"title"`
Fields []struct {
Name string `json:"name"`
Value string `json:"value"`
} `json:"fields"`
} `json:"embeds"`
}
err = json.NewDecoder(bodyReader).Decode(&body) err = json.NewDecoder(bodyReader).Decode(&body)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqual(t, body["embeds"][0]["title"], "abc") ExpectEqual(t, body.Embeds[0].Title, "abc")
fields := ExpectType[[]map[string]any](t, body["embeds"][0]["fields"]) fields := body.Embeds[0].Fields
ExpectEqual(t, fields[0]["name"], "foo") ExpectEqual(t, fields[0].Name, "foo")
ExpectEqual(t, fields[0]["value"], "bar") ExpectEqual(t, fields[0].Value, "bar")
} }

View file

@ -109,8 +109,7 @@ func ExpectType[T any](t *testing.T, got any) (_ T) {
t.Helper() t.Helper()
_, ok := got.(T) _, ok := got.(T)
if !ok { if !ok {
t.Fatalf("expected type %s, got %s", reflect.TypeFor[T](), reflect.TypeOf(got).Elem()) t.Fatalf("expected type %s, got %T", reflect.TypeFor[T](), got)
return
} }
return got.(T) return got.(T)
} }