diff --git a/internal/common/constants.go b/internal/common/constants.go index cf6d0f2..17e6190 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -4,12 +4,6 @@ import ( "time" ) -const ( - ConnectionTimeout = 5 * time.Second - DialTimeout = 3 * time.Second - KeepAlive = 60 * time.Second -) - // file, folder structure const ( @@ -30,6 +24,8 @@ const ( ComposeExampleFileName = "compose.example.yml" ErrorPagesBasePath = "error_pages" + + AgentCertsBasePath = "certs" ) var RequiredDirectories = []string{ @@ -48,5 +44,3 @@ const ( StopTimeoutDefault = "30s" StopMethodDefault = "stop" ) - -const HeaderCheckRedirect = "X-Goproxy-Check-Redirect" diff --git a/internal/net/http/header_utils.go b/internal/net/gphttp/httpheaders/utils.go similarity index 91% rename from internal/net/http/header_utils.go rename to internal/net/gphttp/httpheaders/utils.go index db8c78f..00bed76 100644 --- a/internal/net/http/header_utils.go +++ b/internal/net/gphttp/httpheaders/utils.go @@ -1,4 +1,4 @@ -package http +package httpheaders import ( "net/http" @@ -17,13 +17,15 @@ const ( HeaderXForwardedURI = "X-Forwarded-Uri" HeaderXRealIP = "X-Real-IP" - HeaderUpstreamName = "X-GoDoxy-Upstream-Name" - HeaderUpstreamScheme = "X-GoDoxy-Upstream-Scheme" - HeaderUpstreamHost = "X-GoDoxy-Upstream-Host" - HeaderUpstreamPort = "X-GoDoxy-Upstream-Port" - HeaderContentType = "Content-Type" HeaderContentLength = "Content-Length" + + HeaderUpstreamName = "X-Godoxy-Upstream-Name" + HeaderUpstreamScheme = "X-Godoxy-Upstream-Scheme" + HeaderUpstreamHost = "X-Godoxy-Upstream-Host" + HeaderUpstreamPort = "X-Godoxy-Upstream-Port" + + HeaderGoDoxyCheckRedirect = "X-Godoxy-Check-Redirect" ) // Hop-by-hop headers. These are removed when sent to the backend. diff --git a/internal/net/http/reverseproxy/reverse_proxy_mod.go b/internal/net/gphttp/reverseproxy/reverse_proxy_mod.go similarity index 94% rename from internal/net/http/reverseproxy/reverse_proxy_mod.go rename to internal/net/gphttp/reverseproxy/reverse_proxy_mod.go index aeeb31d..49988f0 100644 --- a/internal/net/http/reverseproxy/reverse_proxy_mod.go +++ b/internal/net/gphttp/reverseproxy/reverse_proxy_mod.go @@ -26,8 +26,8 @@ import ( "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/logging" - gphttp "github.com/yusing/go-proxy/internal/net/http" - "github.com/yusing/go-proxy/internal/net/http/accesslog" + "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/net/types" U "github.com/yusing/go-proxy/internal/utils" "golang.org/x/net/http/httpguts" @@ -168,7 +168,7 @@ func copyHeader(dst, src http.Header) { } func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err error, writeHeader bool) { - reqURL := r.Host + r.RequestURI + reqURL := r.Host + r.URL.Path switch { case errors.Is(err, context.Canceled), errors.Is(err, io.EOF), @@ -266,14 +266,14 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { p.rewriteRequestURL(outreq) outreq.Close = false - reqUpType := gphttp.UpgradeType(outreq.Header) + reqUpType := httpheaders.UpgradeType(outreq.Header) if !IsPrint(reqUpType) { p.errorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType), true) return } req.Header.Del("Forwarded") - gphttp.RemoveHopByHopHeaders(outreq.Header) + httpheaders.RemoveHopByHopHeaders(outreq.Header) // Issue 21096: tell backend applications that care about trailer support // that we support trailers. (We do, but we don't go out of our way to @@ -298,7 +298,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. - prior, ok := outreq.Header[gphttp.HeaderXForwardedFor] + prior, ok := outreq.Header[httpheaders.HeaderXForwardedFor] omit := ok && prior == nil // Issue 38079: nil now means don't populate the header xff, _, err := net.SplitHostPort(req.RemoteAddr) @@ -309,7 +309,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { xff = strings.Join(prior, ", ") + ", " + xff } if !omit { - outreq.Header.Set(gphttp.HeaderXForwardedFor, xff) + outreq.Header.Set(httpheaders.HeaderXForwardedFor, xff) } var reqScheme string @@ -319,10 +319,10 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { reqScheme = "http" } - outreq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method) - outreq.Header.Set(gphttp.HeaderXForwardedProto, reqScheme) - outreq.Header.Set(gphttp.HeaderXForwardedHost, req.Host) - outreq.Header.Set(gphttp.HeaderXForwardedURI, req.RequestURI) + outreq.Header.Set(httpheaders.HeaderXForwardedMethod, req.Method) + outreq.Header.Set(httpheaders.HeaderXForwardedProto, reqScheme) + outreq.Header.Set(httpheaders.HeaderXForwardedHost, req.Host) + outreq.Header.Set(httpheaders.HeaderXForwardedURI, req.RequestURI) if _, ok := outreq.Header["User-Agent"]; !ok { // If the outbound request doesn't have a User-Agent header set, @@ -389,7 +389,7 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { return } - gphttp.RemoveHopByHopHeaders(res.Header) + httpheaders.RemoveHopByHopHeaders(res.Header) if !p.modifyResponse(rw, res, req, outreq) { return @@ -410,15 +410,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - _, err = io.Copy(rw, res.Body) + err = U.CopyCloseWithContext(ctx, rw, res.Body) // close now, instead of defer, to populate res.Trailer if err != nil { if !errors.Is(err, context.Canceled) { - p.errorHandler(rw, req, err, true) + p.errorHandler(rw, req, err, false) } - res.Body.Close() return } - res.Body.Close() // close now, instead of defer, to populate res.Trailer if len(res.Trailer) > 0 { // Force chunking if we saw a response trailer. @@ -460,8 +458,8 @@ func cleanWebsocketHeaders(req *http.Request) { } func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { - reqUpType := gphttp.UpgradeType(req.Header) - resUpType := gphttp.UpgradeType(res.Header) + reqUpType := httpheaders.UpgradeType(req.Header) + resUpType := httpheaders.UpgradeType(res.Header) if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. p.errorHandler(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType), true) return diff --git a/internal/utils/io.go b/internal/utils/io.go index 1cae52f..f24da03 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "net/http" "sync" "syscall" @@ -37,6 +38,14 @@ type ( } ) +func NewContextReader(ctx context.Context, r io.Reader) *ContextReader { + return &ContextReader{ctx: ctx, Reader: r} +} + +func NewContextWriter(ctx context.Context, w io.Writer) *ContextWriter { + return &ContextWriter{ctx: ctx, Writer: w} +} + func (r *ContextReader) Read(p []byte) (int, error) { select { case <-r.ctx.Done(): @@ -63,7 +72,7 @@ func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe { } func (p *Pipe) Start() (err error) { - err = Copy(&p.w, &p.r) + err = CopyClose(&p.w, &p.r) switch { case // NOTE: ignoring broken pipe and connection reset by peer @@ -97,55 +106,123 @@ func (p BidirectionalPipe) Start() gperr.Error { return b.Error() } +var copyBufPool = sync.Pool{ + New: func() any { + return make([]byte, copyBufSize) + }, +} + +type httpFlusher interface { + Flush() error +} + +func getHttpFlusher(dst io.Writer) httpFlusher { + if rw, ok := dst.(http.ResponseWriter); ok { + return http.NewResponseController(rw) + } + return nil +} + +const ( + copyBufSize = 32 * 1024 +) + +var copyBufPool = sync.Pool{ + New: func() any { + return make([]byte, copyBufSize) + }, +} + // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style -// This is a copy of io.Copy with context handling +// This is a copy of io.Copy with context and HTTP flusher handling // Author: yusing . -func Copy(dst *ContextWriter, src *ContextReader) (err error) { - size := 32 * 1024 - if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N { - if l.N < 1 { - size = 1 +func CopyClose(dst *ContextWriter, src *ContextReader) (err error) { + var buf []byte + if l, ok := src.Reader.(*io.LimitedReader); ok { + size := copyBufSize + if int64(size) > l.N { + if l.N < 1 { + size = 1 + } else { + size = int(l.N) + } + } + buf = make([]byte, 0, size) + } else { + buf = copyBufPool.Get().([]byte) + defer copyBufPool.Put(buf) + } + // close both as soon as one of them is done + wCloser, wCanClose := dst.Writer.(io.Closer) + rCloser, rCanClose := src.Reader.(io.Closer) + if wCanClose || rCanClose { + if src.ctx == dst.ctx { + go func() { + <-src.ctx.Done() + if wCanClose { + wCloser.Close() + } + if rCanClose { + rCloser.Close() + } + }() } else { - size = int(l.N) + if wCloser != nil { + go func() { + <-src.ctx.Done() + wCloser.Close() + }() + } + if rCloser != nil { + go func() { + <-dst.ctx.Done() + rCloser.Close() + }() + } } } - buf := make([]byte, size) + flusher := getHttpFlusher(dst.Writer) + canFlush := flusher != nil for { - select { - case <-src.ctx.Done(): - return src.ctx.Err() - case <-dst.ctx.Done(): - return dst.ctx.Err() - default: - nr, er := src.Reader.Read(buf) - if nr > 0 { - nw, ew := dst.Writer.Write(buf[0:nr]) - if nw < 0 || nr < nw { - nw = 0 - if ew == nil { - ew = errors.New("invalid write result") - } - } - if ew != nil { - err = ew - return - } - if nr != nw { - err = io.ErrShortWrite - return + nr, er := src.Reader.Read(buf[:copyBufSize]) + if nr > 0 { + nw, ew := dst.Writer.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errors.New("invalid write result") } } - if er != nil { - if er != io.EOF { - err = er - } + if ew != nil { + err = ew return } + if nr != nw { + err = io.ErrShortWrite + return + } + if canFlush { + err = flusher.Flush() + if err != nil { + if errors.Is(err, http.ErrNotSupported) { + canFlush = false + err = nil + } else { + return err + } + } + } + } + if er != nil { + if er != io.EOF { + err = er + } + return } } } -func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { - return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) +func CopyCloseWithContext(ctx context.Context, dst io.Writer, src io.Reader) (err error) { + return CopyClose(NewContextWriter(ctx, dst), NewContextReader(ctx, src)) }