From 19e3392825997769612b3f1d89ec9170a444d613 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 13 Feb 2025 18:39:35 +0800 Subject: [PATCH] improve reverse proxy and serverhandling - buffer pool for IO copy - flush response after read, now works with event stream - fixed error handling for server --- agent/pkg/handler/docker_socket.go | 85 ++-------------- agent/pkg/server/server.go | 19 ++-- .../http/reverseproxy/reverse_proxy_mod.go | 4 +- internal/net/http/server/error.go | 4 +- internal/net/http/server/server.go | 33 ++++--- internal/route/rules/cache.go | 4 +- internal/utils/io.go | 99 ++++++++++++++++--- 7 files changed, 132 insertions(+), 116 deletions(-) diff --git a/agent/pkg/handler/docker_socket.go b/agent/pkg/handler/docker_socket.go index 53c0aea..e50ffaf 100644 --- a/agent/pkg/handler/docker_socket.go +++ b/agent/pkg/handler/docker_socket.go @@ -1,17 +1,15 @@ package handler import ( - "bufio" - "errors" - "io" "net/http" - "strings" + "net/url" - "github.com/yusing/go-proxy/internal/api/v1/utils" + "github.com/docker/docker/client" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/logging" - godoxyIO "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/net/http/reverseproxy" + "github.com/yusing/go-proxy/internal/net/types" ) func DockerSocketHandler() http.HandlerFunc { @@ -19,75 +17,10 @@ func DockerSocketHandler() http.HandlerFunc { if err != nil { logging.Fatal().Err(err).Msg("failed to connect to docker client") } - dockerDialerCallback := dockerClient.Dialer() + rp := reverseproxy.NewReverseProxy("docker", types.NewURL(&url.URL{ + Scheme: "http", + Host: client.DummyHost, + }), dockerClient.HTTPClient().Transport) - return func(w http.ResponseWriter, r *http.Request) { - conn, err := dockerDialerCallback(r.Context()) - if err != nil { - utils.HandleErr(w, r, err) - return - } - defer conn.Close() - - // Create a done channel to handle cancellation - done := make(chan struct{}) - defer close(done) - - closed := false - - // Start a goroutine to monitor context cancellation - go func() { - select { - case <-r.Context().Done(): - closed = true - conn.Close() // Force close the connection when client disconnects - case <-done: - } - }() - - if err := r.Write(conn); err != nil { - utils.HandleErr(w, r, err) - return - } - - resp, err := http.ReadResponse(bufio.NewReader(conn), r) - if err != nil { - utils.HandleErr(w, r, err) - return - } - defer resp.Body.Close() - - // Set any response headers before writing the status code - for k, v := range resp.Header { - w.Header()[k] = v - } - w.WriteHeader(resp.StatusCode) - - // For event streams, we need to flush the writer to ensure - // events are sent immediately - if f, ok := w.(http.Flusher); ok && strings.HasSuffix(r.URL.Path, "/events") { - // Copy the body in chunks and flush after each write - buf := make([]byte, 2048) - for { - n, err := resp.Body.Read(buf) - if n > 0 { - _, werr := w.Write(buf[:n]) - if werr != nil { - logging.Error().Err(werr).Msg("error writing docker event response") - break - } - f.Flush() - } - if err != nil { - if !closed && !errors.Is(err, io.EOF) { - logging.Error().Err(err).Msg("error reading docker event response") - } - return - } - } - } else { - // For non-event streams, just copy the body - _ = godoxyIO.NewPipe(r.Context(), resp.Body, NopWriteCloser{w}).Start() - } - } + return rp.ServeHTTP } diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go index 369887f..3409f67 100644 --- a/agent/pkg/server/server.go +++ b/agent/pkg/server/server.go @@ -41,21 +41,22 @@ func StartAgentServer(parent task.Parent, opt Options) { tlsConfig.ClientAuth = tls.NoClientCert } + logger := logging.GetLogger() agentServer := &http.Server{ Handler: handler.NewAgentHandler(), TLSConfig: tlsConfig, - ErrorLog: log.New(logging.GetLogger(), "", 0), + ErrorLog: log.New(logger, "", 0), } go func() { l, err := net.Listen("tcp", fmt.Sprintf(":%d", opt.Port)) if err != nil { - logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to listen on port") + server.HandleError(logger, err, "failed to listen on port") return } defer l.Close() if err := agentServer.Serve(tls.NewListener(l, tlsConfig)); err != nil { - logging.Fatal().Err(err).Int("port", opt.Port).Msg("failed to serve") + server.HandleError(logger, err, "failed to serve agent server") } }() @@ -70,24 +71,26 @@ func StartAgentServer(parent task.Parent, opt Options) { err := agentServer.Shutdown(ctx) if err != nil { - logging.Error().Err(err).Int("port", opt.Port).Msg("failed to shutdown agent server") + server.HandleError(logger, err, "failed to shutdown agent server") + } else { + logging.Info().Int("port", opt.Port).Msg("agent server stopped") } - logging.Info().Int("port", opt.Port).Msg("agent server stopped") }() } func StartRegistrationServer(parent task.Parent, opt Options) { t := parent.Subtask("registration_server") + logger := logging.GetLogger() registrationServer := &http.Server{ Addr: fmt.Sprintf(":%d", opt.Port), Handler: handler.NewRegistrationHandler(t, opt.CACert), - ErrorLog: log.New(logging.GetLogger(), "", 0), + ErrorLog: log.New(logger, "", 0), } go func() { err := registrationServer.ListenAndServe() - server.HandleError(logging.GetLogger(), err) + server.HandleError(logger, err, "failed to serve registration server") }() logging.Info().Int("port", opt.Port).Msg("registration server started") @@ -99,7 +102,7 @@ func StartRegistrationServer(parent task.Parent, opt Options) { defer cancel() err := registrationServer.Shutdown(ctx) - server.HandleError(logging.GetLogger(), err) + server.HandleError(logger, err, "failed to shutdown registration server") logging.Info().Int("port", opt.Port).Msg("registration server stopped") } diff --git a/internal/net/http/reverseproxy/reverse_proxy_mod.go b/internal/net/http/reverseproxy/reverse_proxy_mod.go index 6d85eb0..fb0caa0 100644 --- a/internal/net/http/reverseproxy/reverse_proxy_mod.go +++ b/internal/net/http/reverseproxy/reverse_proxy_mod.go @@ -410,15 +410,13 @@ func (p *ReverseProxy) handler(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - _, err = io.Copy(rw, res.Body) + err = U.CopyClose(U.NewContextWriter(ctx, rw), U.NewContextReader(ctx, res.Body)) // close now, instead of defer, to populate res.Trailer if err != nil { if !errors.Is(err, context.Canceled) { p.errorHandler(rw, req, err, true) } - res.Body.Close() return } - res.Body.Close() // close now, instead of defer, to populate res.Trailer if len(res.Trailer) > 0 { // Force chunking if we saw a response trailer. diff --git a/internal/net/http/server/error.go b/internal/net/http/server/error.go index 3b43061..807950c 100644 --- a/internal/net/http/server/error.go +++ b/internal/net/http/server/error.go @@ -8,11 +8,11 @@ import ( "github.com/rs/zerolog" ) -func HandleError(logger *zerolog.Logger, err error) { +func HandleError(logger *zerolog.Logger, err error, msg string) { switch { case err == nil, errors.Is(err, http.ErrServerClosed), errors.Is(err, context.Canceled): return default: - logger.Fatal().Err(err).Msg("server error") + logger.Fatal().Err(err).Msg(msg) } } diff --git a/internal/net/http/server/server.go b/internal/net/http/server/server.go index fee526e..40c39c9 100644 --- a/internal/net/http/server/server.go +++ b/internal/net/http/server/server.go @@ -99,7 +99,10 @@ func (s *Server) Start(parent task.Parent) { s.startTime = time.Now() if s.http != nil { go func() { - s.handleErr(s.http.ListenAndServe()) + err := s.http.ListenAndServe() + if err != nil { + s.handleErr(err, "failed to serve http server") + } }() s.httpStarted = true s.l.Info().Str("addr", s.http.Addr).Msg("server started") @@ -109,11 +112,11 @@ func (s *Server) Start(parent task.Parent) { go func() { l, err := net.Listen("tcp", s.https.Addr) if err != nil { - s.handleErr(err) + s.handleErr(err, "failed to listen on port") return } defer l.Close() - s.handleErr(s.https.Serve(tls.NewListener(l, s.https.TLSConfig))) + s.handleErr(s.https.Serve(tls.NewListener(l, s.https.TLSConfig)), "failed to serve https server") }() s.httpsStarted = true s.l.Info().Str("addr", s.https.Addr).Msgf("server started") @@ -131,15 +134,23 @@ func (s *Server) stop() { defer cancel() if s.http != nil && s.httpStarted { - s.handleErr(s.http.Shutdown(ctx)) - s.httpStarted = false - s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped") + err := s.http.Shutdown(ctx) + if err != nil { + s.handleErr(err, "failed to shutdown http server") + } else { + s.httpStarted = false + s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped") + } } if s.https != nil && s.httpsStarted { - s.handleErr(s.https.Shutdown(ctx)) - s.httpsStarted = false - s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped") + err := s.https.Shutdown(ctx) + if err != nil { + s.handleErr(err, "failed to shutdown https server") + } else { + s.httpsStarted = false + s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped") + } } } @@ -147,6 +158,6 @@ func (s *Server) Uptime() time.Duration { return time.Since(s.startTime) } -func (s *Server) handleErr(err error) { - HandleError(&s.l, err) +func (s *Server) handleErr(err error, msg string) { + HandleError(&s.l, err, msg) } diff --git a/internal/route/rules/cache.go b/internal/route/rules/cache.go index 532f508..c57965a 100644 --- a/internal/route/rules/cache.go +++ b/internal/route/rules/cache.go @@ -41,9 +41,7 @@ func NewCache() Cache { // Release clear the contents of the Cached and returns it to the pool. func (c Cache) Release() { - for _, k := range cacheKeys { - delete(c, k) - } + clear(c) cachePool.Put(c) } diff --git a/internal/utils/io.go b/internal/utils/io.go index 0e65c46..4ff8653 100644 --- a/internal/utils/io.go +++ b/internal/utils/io.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "net/http" "sync" "syscall" @@ -37,6 +38,14 @@ type ( } ) +func NewContextReader(ctx context.Context, r io.Reader) *ContextReader { + return &ContextReader{ctx: ctx, Reader: r} +} + +func NewContextWriter(ctx context.Context, w io.Writer) *ContextWriter { + return &ContextWriter{ctx: ctx, Writer: w} +} + func (r *ContextReader) Read(p []byte) (int, error) { select { case <-r.ctx.Done(): @@ -63,7 +72,7 @@ func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe { } func (p *Pipe) Start() (err error) { - err = Copy(&p.w, &p.r) + err = CopyClose(&p.w, &p.r) switch { case // NOTE: ignoring broken pipe and connection reset by peer @@ -97,20 +106,78 @@ func (p BidirectionalPipe) Start() E.Error { return b.Error() } +var copyBufPool = sync.Pool{ + New: func() any { + return make([]byte, copyBufSize) + }, +} + +type httpFlusher interface { + Flush() error +} + +func getHttpFlusher(dst io.Writer) httpFlusher { + if rw, ok := dst.(http.ResponseWriter); ok { + return http.NewResponseController(rw) + } + return nil +} + +const ( + copyBufSize = 32 * 1024 +) + // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style -// This is a copy of io.Copy with context handling +// This is a copy of io.Copy with context and HTTP flusher handling // Author: yusing . -func Copy(dst *ContextWriter, src *ContextReader) (err error) { - size := 32 * 1024 - if l, ok := src.Reader.(*io.LimitedReader); ok && int64(size) > l.N { - if l.N < 1 { - size = 1 +func CopyClose(dst *ContextWriter, src *ContextReader) (err error) { + var buf []byte + if l, ok := src.Reader.(*io.LimitedReader); ok { + size := copyBufSize + if int64(size) > l.N { + if l.N < 1 { + size = 1 + } else { + size = int(l.N) + } + } + buf = make([]byte, size) + } else { + buf = copyBufPool.Get().([]byte) + defer copyBufPool.Put(buf) + } + // close both as soon as one of them is done + wCloser, wCanClose := dst.Writer.(io.Closer) + rCloser, rCanClose := src.Reader.(io.Closer) + if wCanClose || rCanClose { + if src.ctx == dst.ctx { + go func() { + <-src.ctx.Done() + if wCanClose { + wCloser.Close() + } + if rCanClose { + rCloser.Close() + } + }() } else { - size = int(l.N) + if wCloser != nil { + go func() { + <-src.ctx.Done() + wCloser.Close() + }() + } + if rCloser != nil { + go func() { + <-dst.ctx.Done() + rCloser.Close() + }() + } } } - buf := make([]byte, size) + flusher := getHttpFlusher(dst.Writer) + canFlush := flusher != nil for { select { case <-src.ctx.Done(): @@ -135,6 +202,16 @@ func Copy(dst *ContextWriter, src *ContextReader) (err error) { err = io.ErrShortWrite return } + if canFlush { + err = flusher.Flush() + if err != nil { + if errors.Is(err, http.ErrNotSupported) { + canFlush = false + } else { + return err + } + } + } } if er != nil { if er != io.EOF { @@ -145,7 +222,3 @@ func Copy(dst *ContextWriter, src *ContextReader) (err error) { } } } - -func Copy2(ctx context.Context, dst io.Writer, src io.Reader) error { - return Copy(&ContextWriter{ctx: ctx, Writer: dst}, &ContextReader{ctx: ctx, Reader: src}) -}