diff --git a/internal/api/v1/favicon/favicon.go b/internal/api/v1/favicon/favicon.go index ccbb9a3..ce38b1e 100644 --- a/internal/api/v1/favicon/favicon.go +++ b/internal/api/v1/favicon/favicon.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "net/url" - "path" "strings" "time" @@ -43,6 +42,10 @@ func (res *fetchResult) ContentType() string { return res.contentType } +const ( + MaxRedirectDepth = 5 +) + // GetFavIcon returns the favicon of the route // // Returns: @@ -195,7 +198,7 @@ func findIcon(r route.HTTPRoute, req *http.Request, uri string) *fetchResult { } if !result.OK() { // fallback to parse html - result = findIconSlow(r, req, uri) + result = findIconSlow(r, req, uri, 0) } if result.OK() { storeIconCache(key, result.icon) @@ -203,7 +206,7 @@ func findIcon(r route.HTTPRoute, req *http.Request, uri string) *fetchResult { return result } -func findIconSlow(r route.HTTPRoute, req *http.Request, uri string) *fetchResult { +func findIconSlow(r route.HTTPRoute, req *http.Request, uri string, depth int) *fetchResult { ctx, cancel := context.WithTimeoutCause(req.Context(), 3*time.Second, errors.New("favicon request timeout")) defer cancel() newReq := req.WithContext(ctx) @@ -229,11 +232,14 @@ func findIconSlow(r route.HTTPRoute, req *http.Request, uri string) *fetchResult return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "connection error"} default: if loc := c.Header().Get("Location"); loc != "" { + if depth > MaxRedirectDepth { + return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "too many redirects"} + } loc = strutils.SanitizeURI(loc) if loc == "/" || loc == newReq.URL.Path { return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "circular redirect"} } - return findIconSlow(r, req, loc) + return findIconSlow(r, req, loc, depth+1) } } return &fetchResult{statusCode: c.status, errMsg: "upstream error: " + string(c.data)} @@ -273,6 +279,6 @@ func findIconSlow(r route.HTTPRoute, req *http.Request, uri string) *fetchResult case strings.HasPrefix(href, "http://"), strings.HasPrefix(href, "https://"): return fetchIconAbsolute(href) default: - return findIconSlow(r, req, path.Clean(href)) + return findIconSlow(r, req, href, 0) } } diff --git a/internal/utils/strutils/url.go b/internal/utils/strutils/url.go index 4f913a1..b587c24 100644 --- a/internal/utils/strutils/url.go +++ b/internal/utils/strutils/url.go @@ -1,6 +1,9 @@ package strutils -import "path" +import ( + "path" + "strings" +) // SanitizeURI sanitizes a URI reference to ensure it is safe // It disallows URLs beginning with // or /\ as absolute URLs, @@ -10,6 +13,9 @@ func SanitizeURI(uri string) string { if uri == "" { return "/" } + if strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://") { + return uri + } if uri[0] != '/' { uri = "/" + uri }