diff --git a/internal/net/gphttp/middleware/middlewares.go b/internal/net/gphttp/middleware/middlewares.go index 8ae4ee9..0e98bc4 100644 --- a/internal/net/gphttp/middleware/middlewares.go +++ b/internal/net/gphttp/middleware/middlewares.go @@ -24,6 +24,8 @@ var allMiddlewares = map[string]*Middleware{ "setxforwarded": SetXForwarded, "hidexforwarded": HideXForwarded, + "modifyhtml": ModifyHTML, + "errorpage": CustomErrorPage, "customerrorpage": CustomErrorPage, diff --git a/internal/net/gphttp/middleware/modify_html.go b/internal/net/gphttp/middleware/modify_html.go new file mode 100644 index 0000000..a397949 --- /dev/null +++ b/internal/net/gphttp/middleware/modify_html.go @@ -0,0 +1,106 @@ +package middleware + +import ( + "bytes" + "io" + "net/http" + "strconv" + + "github.com/PuerkitoBio/goquery" + "github.com/rs/zerolog/log" + gphttp "github.com/yusing/go-proxy/internal/net/gphttp" + "golang.org/x/net/html" +) + +type modifyHTML struct { + Target string // css selector + HTML string // html to inject + Replace bool // replace the target element with the new html instead of appending it +} + +var ModifyHTML = NewMiddleware[modifyHTML]() + +// modifyResponse implements ResponseModifier. +func (m *modifyHTML) modifyResponse(resp *http.Response) error { + // including text/html and application/xhtml+xml + if !gphttp.GetContentType(resp.Header).IsHTML() { + return nil + } + + content, err := io.ReadAll(resp.Body) + if err != nil { + resp.Body.Close() + return err + } + resp.Body.Close() + + doc, err := goquery.NewDocumentFromReader(bytes.NewReader(content)) + if err != nil { + // invalid html, restore the original body + resp.Body = io.NopCloser(bytes.NewReader(content)) + log.Err(err).Str("url", fullURL(resp.Request)).Msg("invalid html found") + return nil + } + + ele := doc.Find(m.Target) + if ele.Length() == 0 { + // no target found, restore the original body + resp.Body = io.NopCloser(bytes.NewReader(content)) + return nil + } + + if m.Replace { + // replace all matching elements + ele.ReplaceWithHtml(m.HTML) + } else { + // append to the first matching element + ele.First().AppendHtml(m.HTML) + } + + h, err := buildHTML(doc) + if err != nil { + return err + } + resp.ContentLength = int64(len(h)) + resp.Header.Set("Content-Length", strconv.Itoa(len(h))) + resp.Body = io.NopCloser(bytes.NewReader(h)) + return nil +} + +// copied and modified from (*goquery.Selection).Html() +func buildHTML(s *goquery.Document) (ret []byte, err error) { + var buf bytes.Buffer + + // Merge all head nodes into one + headNodes := s.Find("head") + if headNodes.Length() > 1 { + // Get the first head node to merge everything into + firstHead := headNodes.First() + + // Merge content from all other head nodes into the first one + headNodes.Slice(1, headNodes.Length()).Each(func(i int, otherHead *goquery.Selection) { + // Move all children from other head nodes to the first head + otherHead.Children().Each(func(j int, child *goquery.Selection) { + firstHead.AppendSelection(child) + }) + }) + + // Remove the duplicate head nodes (keep only the first one) + headNodes.Slice(1, headNodes.Length()).Remove() + } + + if len(s.Nodes) > 0 { + for c := s.Nodes[0].FirstChild; c != nil; c = c.NextSibling { + err = html.Render(&buf, c) + if err != nil { + return + } + } + ret = buf.Bytes() + } + return +} + +func fullURL(req *http.Request) string { + return req.Host + req.RequestURI +} diff --git a/internal/net/gphttp/middleware/modify_html_test.go b/internal/net/gphttp/middleware/modify_html_test.go new file mode 100644 index 0000000..fa36fd7 --- /dev/null +++ b/internal/net/gphttp/middleware/modify_html_test.go @@ -0,0 +1,557 @@ +package middleware + +import ( + "net/http" + "strconv" + "strings" + "testing" + + expect "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestInjectCSS(t *testing.T) { + opts := OptionsRaw{ + "target": "head", + "html": "", + } + result, err := newMiddlewareTest(ModifyHTML, &testArgs{ + middlewareOpt: opts, + respHeaders: http.Header{ + "Content-Type": []string{"text/html; charset=utf-8"}, + }, + respBody: []byte(` + +
+Injected content
", + } + result, err := newMiddlewareTest(ModifyHTML, &testArgs{ + middlewareOpt: opts, + respHeaders: http.Header{ + "Content-Type": []string{"text/html"}, + }, + respBody: []byte(` + + + +Injected content
Some content here.
+Some content here.
+Some content`), + }) + expect.NoError(t, err) + // Should handle malformed HTML gracefully + expect.True(t, strings.Contains(string(result.Data), "Valid injection"), "Should inject content even with malformed HTML") +} + +func TestInjectHTML_ContentTypes(t *testing.T) { + testCases := []struct { + name string + contentType string + shouldModify bool + }{ + {"HTML with charset", "text/html; charset=utf-8", true}, + {"Plain HTML", "text/html", true}, + {"XHTML", "application/xhtml+xml", true}, + {"JSON", "application/json", false}, + {"Plain text", "text/plain", false}, + {"JavaScript", "application/javascript", false}, + {"CSS", "text/css", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := OptionsRaw{ + "target": "body", + "html": "
More original content
+ + + `), + }) + expect.NoError(t, err) + expect.Equal(t, removeTabsAndNewlines(result.Data), removeTabsAndNewlines(` + + +Replaced content
", + "replace": true, + } + result, err := newMiddlewareTest(ModifyHTML, &testArgs{ + middlewareOpt: opts, + respHeaders: http.Header{ + "Content-Type": []string{"text/html"}, + }, + respBody: []byte(` + + + +Replaced content
+Replaced content
+ + + ` + expect.Equal(t, removeTabsAndNewlines(result.Data), removeTabsAndNewlines(expectedContent)) +} + +func TestInjectHTML_ReplaceComplexHTML(t *testing.T) { + opts := OptionsRaw{ + "target": "main", + "html": `This replaces the entire main content.
Original content that will be replaced.
+This replaces the entire main content.