From 57200bc1e9f176640647fc0a7664168e9316d927 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 31 May 2025 13:54:50 +0800 Subject: [PATCH] refactor(io): enhance HTTP flusher handling --- internal/utils/io.go | 45 +++++++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/internal/utils/io.go b/internal/utils/io.go index ee40600..2e1702d 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -106,13 +106,38 @@ func (p BidirectionalPipe) Start() error { return errors.Join(srcErr, dstErr) } -type httpFlusher interface { - Flush() error +type flushErrorInterface interface { + FlushError() error } -func getHTTPFlusher(dst io.Writer) httpFlusher { +type flusherWrapper struct { + rw http.Flusher +} + +type rwUnwrapper interface { + Unwrap() http.ResponseWriter +} + +func (f *flusherWrapper) FlushError() error { + f.rw.Flush() + return nil +} + +func getHTTPFlusher(dst io.Writer) flushErrorInterface { + // pre-unwrap the flusher to prevent unwrap and check in every loop if rw, ok := dst.(http.ResponseWriter); ok { - return http.NewResponseController(rw) + for { + switch t := rw.(type) { + case flushErrorInterface: + return t + case http.Flusher: + return &flusherWrapper{rw: t} + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil + } + } } return nil } @@ -158,7 +183,6 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error) }() } flusher := getHTTPFlusher(dst.Writer) - canFlush := flusher != nil for { nr, er := src.Reader.Read(buf) if nr > 0 { @@ -177,15 +201,10 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error) err = io.ErrShortWrite return } - if canFlush { - err = flusher.Flush() + if flusher != nil { + err = flusher.FlushError() if err != nil { - if errors.Is(err, http.ErrNotSupported) { - canFlush = false - err = nil - } else { - return err - } + return err } } }