fix: limit redirect count when parsing html for favicon, fix url sanitize method

This commit is contained in:
yusing 2025-03-29 09:35:12 +08:00
parent d2e2086540
commit 146e7781be
2 changed files with 18 additions and 6 deletions

View file

@ -7,7 +7,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"path"
"strings" "strings"
"time" "time"
@ -43,6 +42,10 @@ func (res *fetchResult) ContentType() string {
return res.contentType return res.contentType
} }
const (
MaxRedirectDepth = 5
)
// GetFavIcon returns the favicon of the route // GetFavIcon returns the favicon of the route
// //
// Returns: // Returns:
@ -195,7 +198,7 @@ func findIcon(r route.HTTPRoute, req *http.Request, uri string) *fetchResult {
} }
if !result.OK() { if !result.OK() {
// fallback to parse html // fallback to parse html
result = findIconSlow(r, req, uri) result = findIconSlow(r, req, uri, 0)
} }
if result.OK() { if result.OK() {
storeIconCache(key, result.icon) storeIconCache(key, result.icon)
@ -203,7 +206,7 @@ func findIcon(r route.HTTPRoute, req *http.Request, uri string) *fetchResult {
return result return result
} }
func findIconSlow(r route.HTTPRoute, req *http.Request, uri string) *fetchResult { func findIconSlow(r route.HTTPRoute, req *http.Request, uri string, depth int) *fetchResult {
ctx, cancel := context.WithTimeoutCause(req.Context(), 3*time.Second, errors.New("favicon request timeout")) ctx, cancel := context.WithTimeoutCause(req.Context(), 3*time.Second, errors.New("favicon request timeout"))
defer cancel() defer cancel()
newReq := req.WithContext(ctx) newReq := req.WithContext(ctx)
@ -229,11 +232,14 @@ func findIconSlow(r route.HTTPRoute, req *http.Request, uri string) *fetchResult
return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "connection error"} return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "connection error"}
default: default:
if loc := c.Header().Get("Location"); loc != "" { if loc := c.Header().Get("Location"); loc != "" {
if depth > MaxRedirectDepth {
return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "too many redirects"}
}
loc = strutils.SanitizeURI(loc) loc = strutils.SanitizeURI(loc)
if loc == "/" || loc == newReq.URL.Path { if loc == "/" || loc == newReq.URL.Path {
return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "circular redirect"} return &fetchResult{statusCode: http.StatusBadGateway, errMsg: "circular redirect"}
} }
return findIconSlow(r, req, loc) return findIconSlow(r, req, loc, depth+1)
} }
} }
return &fetchResult{statusCode: c.status, errMsg: "upstream error: " + string(c.data)} return &fetchResult{statusCode: c.status, errMsg: "upstream error: " + string(c.data)}
@ -273,6 +279,6 @@ func findIconSlow(r route.HTTPRoute, req *http.Request, uri string) *fetchResult
case strings.HasPrefix(href, "http://"), strings.HasPrefix(href, "https://"): case strings.HasPrefix(href, "http://"), strings.HasPrefix(href, "https://"):
return fetchIconAbsolute(href) return fetchIconAbsolute(href)
default: default:
return findIconSlow(r, req, path.Clean(href)) return findIconSlow(r, req, href, 0)
} }
} }

View file

@ -1,6 +1,9 @@
package strutils package strutils
import "path" import (
"path"
"strings"
)
// SanitizeURI sanitizes a URI reference to ensure it is safe // SanitizeURI sanitizes a URI reference to ensure it is safe
// It disallows URLs beginning with // or /\ as absolute URLs, // It disallows URLs beginning with // or /\ as absolute URLs,
@ -10,6 +13,9 @@ func SanitizeURI(uri string) string {
if uri == "" { if uri == "" {
return "/" return "/"
} }
if strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://") {
return uri
}
if uri[0] != '/' { if uri[0] != '/' {
uri = "/" + uri uri = "/" + uri
} }