refactor(io): enhance HTTP flusher handling

This commit is contained in:
yusing 2025-05-31 13:54:50 +08:00
parent 6f9bb410f5
commit 57200bc1e9

View file

@ -106,13 +106,38 @@ func (p BidirectionalPipe) Start() error {
return errors.Join(srcErr, dstErr) return errors.Join(srcErr, dstErr)
} }
type httpFlusher interface { type flushErrorInterface interface {
Flush() error FlushError() error
} }
func getHTTPFlusher(dst io.Writer) httpFlusher { type flusherWrapper struct {
rw http.Flusher
}
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
func (f *flusherWrapper) FlushError() error {
f.rw.Flush()
return nil
}
func getHTTPFlusher(dst io.Writer) flushErrorInterface {
// pre-unwrap the flusher to prevent unwrap and check in every loop
if rw, ok := dst.(http.ResponseWriter); ok { if rw, ok := dst.(http.ResponseWriter); ok {
return http.NewResponseController(rw) for {
switch t := rw.(type) {
case flushErrorInterface:
return t
case http.Flusher:
return &flusherWrapper{rw: t}
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil
}
}
} }
return nil return nil
} }
@ -158,7 +183,6 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error)
}() }()
} }
flusher := getHTTPFlusher(dst.Writer) flusher := getHTTPFlusher(dst.Writer)
canFlush := flusher != nil
for { for {
nr, er := src.Reader.Read(buf) nr, er := src.Reader.Read(buf)
if nr > 0 { if nr > 0 {
@ -177,15 +201,10 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error)
err = io.ErrShortWrite err = io.ErrShortWrite
return return
} }
if canFlush { if flusher != nil {
err = flusher.Flush() err = flusher.FlushError()
if err != nil { if err != nil {
if errors.Is(err, http.ErrNotSupported) { return err
canFlush = false
err = nil
} else {
return err
}
} }
} }
} }