refactor(io): enhance HTTP flusher handling

This commit is contained in:
yusing 2025-05-31 13:54:50 +08:00
parent 6f9bb410f5
commit 57200bc1e9

View file

@ -106,13 +106,38 @@ func (p BidirectionalPipe) Start() error {
return errors.Join(srcErr, dstErr)
}
type httpFlusher interface {
Flush() error
type flushErrorInterface interface {
FlushError() error
}
func getHTTPFlusher(dst io.Writer) httpFlusher {
type flusherWrapper struct {
rw http.Flusher
}
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
func (f *flusherWrapper) FlushError() error {
f.rw.Flush()
return nil
}
func getHTTPFlusher(dst io.Writer) flushErrorInterface {
// pre-unwrap the flusher to prevent unwrap and check in every loop
if rw, ok := dst.(http.ResponseWriter); ok {
return http.NewResponseController(rw)
for {
switch t := rw.(type) {
case flushErrorInterface:
return t
case http.Flusher:
return &flusherWrapper{rw: t}
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil
}
}
}
return nil
}
@ -158,7 +183,6 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error)
}()
}
flusher := getHTTPFlusher(dst.Writer)
canFlush := flusher != nil
for {
nr, er := src.Reader.Read(buf)
if nr > 0 {
@ -177,15 +201,10 @@ func CopyClose(dst *ContextWriter, src *ContextReader, sizeHint int) (err error)
err = io.ErrShortWrite
return
}
if canFlush {
err = flusher.Flush()
if flusher != nil {
err = flusher.FlushError()
if err != nil {
if errors.Is(err, http.ErrNotSupported) {
canFlush = false
err = nil
} else {
return err
}
return err
}
}
}