diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go index fb0910a..9d6c336 100644 --- a/agent/pkg/server/server.go +++ b/agent/pkg/server/server.go @@ -1,14 +1,11 @@ package server import ( - "context" "crypto/tls" "crypto/x509" "encoding/pem" "fmt" - "net" "net/http" - "time" "github.com/yusing/go-proxy/agent/pkg/env" "github.com/yusing/go-proxy/agent/pkg/handler" @@ -42,36 +39,13 @@ func StartAgentServer(parent task.Parent, opt Options) { logger := logging.GetLogger() agentServer := &http.Server{ + Addr: fmt.Sprintf(":%d", opt.Port), Handler: handler.NewAgentHandler(), TLSConfig: tlsConfig, } - go func() { - l, err := net.Listen("tcp", fmt.Sprintf(":%d", opt.Port)) - if err != nil { - server.HandleError(logger, err, "failed to listen on port") - return - } - defer l.Close() - if err := agentServer.Serve(tls.NewListener(l, tlsConfig)); err != nil { - server.HandleError(logger, err, "failed to serve agent server") - } - }() - - logging.Info().Int("port", opt.Port).Msg("agent server started") - - go func() { - defer t.Finish(nil) - <-parent.Context().Done() - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - err := agentServer.Shutdown(ctx) - if err != nil { - server.HandleError(logger, err, "failed to shutdown agent server") - } else { - logging.Info().Int("port", opt.Port).Msg("agent server stopped") - } - }() + server.Start(t, agentServer, logger) + t.OnCancel("stop", func() { + server.Stop(agentServer, logger) + }) } diff --git a/internal/net/gphttp/server/server.go b/internal/net/gphttp/server/server.go index 752cec7..1e7b132 100644 --- a/internal/net/gphttp/server/server.go +++ b/internal/net/gphttp/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" "crypto/tls" - "io" "log" "net" "net/http" @@ -21,8 +20,6 @@ type Server struct { CertProvider *autocert.Provider http *http.Server https *http.Server - httpStarted bool - httpsStarted bool startTime time.Time l zerolog.Logger @@ -53,23 +50,16 @@ func NewServer(opt Options) (s *Server) { certAvailable = err == nil } - out := io.Discard - if common.IsDebug { - out = logger - } - if opt.HTTPAddr != "" { httpSer = &http.Server{ - Addr: opt.HTTPAddr, - Handler: opt.Handler, - ErrorLog: log.New(out, "", 0), // most are tls related + Addr: opt.HTTPAddr, + Handler: opt.Handler, } } if certAvailable && opt.HTTPSAddr != "" { httpsSer = &http.Server{ - Addr: opt.HTTPSAddr, - Handler: opt.Handler, - ErrorLog: log.New(out, "", 0), // most are tls related + Addr: opt.HTTPSAddr, + Handler: opt.Handler, TLSConfig: &tls.Config{ GetCertificate: opt.CertProvider.GetCert, }, @@ -90,74 +80,80 @@ func NewServer(opt Options) (s *Server) { // // Start() is non-blocking. func (s *Server) Start(parent task.Parent) { - if s.http == nil && s.https == nil { - return - } - - task := parent.Subtask("server."+s.Name, false) - s.startTime = time.Now() - if s.http != nil { - go func() { - err := s.http.ListenAndServe() - if err != nil { - s.handleErr(err, "failed to serve http server") - } - }() - s.httpStarted = true - s.l.Info().Str("addr", s.http.Addr).Msg("server started") - } - - if s.https != nil { - go func() { - l, err := net.Listen("tcp", s.https.Addr) - if err != nil { - s.handleErr(err, "failed to listen on port") - return - } - defer l.Close() - s.handleErr(s.https.Serve(tls.NewListener(l, s.https.TLSConfig)), "failed to serve https server") - }() - s.httpsStarted = true - s.l.Info().Str("addr", s.https.Addr).Msgf("server started") - } - - task.OnCancel("stop", s.stop) + subtask := parent.Subtask("server."+s.Name, false) + Start(subtask, s.http, &s.l) + Start(subtask, s.https, &s.l) } -func (s *Server) stop() { - if s.http == nil && s.https == nil { +func Start(parent task.Parent, srv *http.Server, logger *zerolog.Logger) { + if srv == nil { + return + } + srv.BaseContext = func(l net.Listener) context.Context { + return parent.Context() + } + + if common.IsDebug { + srv.ErrorLog = log.New(logger, "", 0) + } + + var proto string + if srv.TLSConfig == nil { + proto = "http" + } else { + proto = "https" + } + + task := parent.Subtask(proto, false) + + var lc net.ListenConfig + + go func() { + // Serve already closes the listener on return + l, err := lc.Listen(task.Context(), "tcp", srv.Addr) + if err != nil { + HandleError(logger, err, "failed to listen on port") + return + } + + task.OnCancel("stop", func() { + Stop(srv, logger) + }) + + logger.Info().Str("addr", srv.Addr).Msg("server started") + + if srv.TLSConfig == nil { + err = srv.Serve(l) + } else { + err = srv.Serve(tls.NewListener(l, srv.TLSConfig)) + } + HandleError(logger, err, "failed to serve "+proto+" server") + }() +} + +func Stop(srv *http.Server, logger *zerolog.Logger) { + if srv == nil { return } - ctx, cancel := context.WithTimeout(task.RootContext(), 5*time.Second) - defer cancel() - - if s.http != nil && s.httpStarted { - err := s.http.Shutdown(ctx) - if err != nil { - s.handleErr(err, "failed to shutdown http server") - } else { - s.httpStarted = false - s.l.Info().Str("addr", s.http.Addr).Msgf("server stopped") - } + var proto string + if srv.TLSConfig == nil { + proto = "http" + } else { + proto = "https" } - if s.https != nil && s.httpsStarted { - err := s.https.Shutdown(ctx) - if err != nil { - s.handleErr(err, "failed to shutdown https server") - } else { - s.httpsStarted = false - s.l.Info().Str("addr", s.https.Addr).Msgf("server stopped") - } + ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + HandleError(logger, err, "failed to shutdown "+proto+" server") + } else { + logger.Info().Str("addr", srv.Addr).Msgf("server stopped") } } func (s *Server) Uptime() time.Duration { return time.Since(s.startTime) } - -func (s *Server) handleErr(err error, msg string) { - HandleError(&s.l, err, msg) -}