diff --git a/internal/utils/io.go b/internal/utils/io.go index a470edb..c8423f2 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -132,6 +132,17 @@ const ( // This is a copy of io.Copy with context and HTTP flusher handling // Author: yusing . func CopyClose(dst *ContextWriter, src *ContextReader) (err error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.Reader.(io.WriterTo); ok { + _, err = wt.WriteTo(dst) + return + } + // Similarly, if the writer has a ReadFrom method, use it to do the copy. + if rf, ok := dst.Writer.(io.ReaderFrom); ok { + _, err = rf.ReadFrom(src) + return + } var buf []byte if l, ok := src.Reader.(*io.LimitedReader); ok { size := copyBufSize @@ -142,7 +153,7 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) { size = int(l.N) } } - buf = make([]byte, size) + buf = make([]byte, 0, size) } else { buf = copyBufPool.Get().([]byte) defer copyBufPool.Put(buf) @@ -179,47 +190,40 @@ func CopyClose(dst *ContextWriter, src *ContextReader) (err error) { flusher := getHttpFlusher(dst.Writer) canFlush := flusher != nil for { - select { - case <-src.ctx.Done(): - return src.ctx.Err() - case <-dst.ctx.Done(): - return dst.ctx.Err() - default: - nr, er := src.Reader.Read(buf) - if nr > 0 { - nw, ew := dst.Writer.Write(buf[0:nr]) - if nw < 0 || nr < nw { - nw = 0 - if ew == nil { - ew = errors.New("invalid write result") - } - } - if ew != nil { - err = ew - return - } - if nr != nw { - err = io.ErrShortWrite - return - } - if canFlush { - err = flusher.Flush() - if err != nil { - if errors.Is(err, http.ErrNotSupported) { - canFlush = false - err = nil - } else { - return err - } - } + nr, er := src.Reader.Read(buf[:copyBufSize]) + if nr > 0 { + nw, ew := dst.Writer.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errors.New("invalid write result") } } - if er != nil { - if er != io.EOF { - err = er - } + if ew != nil { + err = ew return } + if nr != nw { + err = io.ErrShortWrite + return + } + if canFlush { + err = flusher.Flush() + if err != nil { + if errors.Is(err, http.ErrNotSupported) { + canFlush = false + err = nil + } else { + return err + } + } + } + } + if er != nil { + if er != io.EOF { + err = er + } + return } } }