diff --git a/internal/common/env.go b/internal/common/env.go index c73028b..767d411 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -19,6 +19,8 @@ var ( IsDebug = GetEnvBool("DEBUG", IsTest) IsTrace = GetEnvBool("TRACE", false) && IsDebug + HTTP3Enabled = GetEnvBool("HTTP3_ENABLED", true) + ProxyHTTPAddr, ProxyHTTPHost, ProxyHTTPPort, diff --git a/internal/net/gphttp/server/server.go b/internal/net/gphttp/server/server.go index 61f63c2..92408af 100644 --- a/internal/net/gphttp/server/server.go +++ b/internal/net/gphttp/server/server.go @@ -3,11 +3,11 @@ package server import ( "context" "crypto/tls" - "log" "net" "net/http" "time" + "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" @@ -33,6 +33,11 @@ type Options struct { Handler http.Handler } +type httpServer interface { + *http.Server | *http3.Server + Shutdown(ctx context.Context) error +} + func StartServer(parent task.Parent, opt Options) (s *Server) { s = NewServer(opt) s.Start(parent) @@ -82,67 +87,74 @@ func NewServer(opt Options) (s *Server) { func (s *Server) Start(parent task.Parent) { s.startTime = time.Now() subtask := parent.Subtask("server."+s.Name, false) + + if s.https != nil && common.HTTP3Enabled { + s.https.TLSConfig.NextProtos = []string{http3.NextProtoH3, "h2", "http/1.1"} + h3 := &http3.Server{ + Addr: s.https.Addr, + Handler: s.https.Handler, + TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig), + } + Start(subtask, h3, &s.l) + s.http.Handler = advertiseHTTP3(s.http.Handler, h3) + s.https.Handler = advertiseHTTP3(s.https.Handler, h3) + } + Start(subtask, s.http, &s.l) Start(subtask, s.https, &s.l) } -func Start(parent task.Parent, srv *http.Server, logger *zerolog.Logger) { +func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Logger) { if srv == nil { return } - srv.BaseContext = func(l net.Listener) context.Context { - return parent.Context() - } - if common.IsDebug { - srv.ErrorLog = log.New(logger, "", 0) - } - - var proto string - if srv.TLSConfig == nil { - proto = "http" - } else { - proto = "https" - } + setDebugLogger(srv, logger) + proto := proto(srv) task := parent.Subtask(proto, false) var lc net.ListenConfig + var serveFunc func() error - // Serve already closes the listener on return - l, err := lc.Listen(task.Context(), "tcp", srv.Addr) - if err != nil { - HandleError(logger, err, "failed to listen on port") - return - } - - task.OnCancel("stop", func() { - Stop(srv, logger) - }) - - logger.Info().Str("addr", srv.Addr).Msg("server started") - - go func() { - if srv.TLSConfig == nil { - err = srv.Serve(l) - } else { - err = srv.Serve(tls.NewListener(l, srv.TLSConfig)) + switch srv := any(srv).(type) { + case *http.Server: + srv.BaseContext = func(l net.Listener) context.Context { + return parent.Context() } + l, err := lc.Listen(task.Context(), "tcp", srv.Addr) + if err != nil { + HandleError(logger, err, "failed to listen on port") + return + } + if srv.TLSConfig != nil { + l = tls.NewListener(l, srv.TLSConfig) + } + serveFunc = getServeFunc(l, srv.Serve) + case *http3.Server: + l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr) + if err != nil { + HandleError(logger, err, "failed to listen on port") + return + } + serveFunc = getServeFunc(l, srv.Serve) + } + task.OnCancel("stop", func() { + stop(srv, logger) + }) + logStarted(srv, logger) + go func() { + err := serveFunc() HandleError(logger, err, "failed to serve "+proto+" server") }() } -func Stop(srv *http.Server, logger *zerolog.Logger) { +func stop[Server httpServer](srv Server, logger *zerolog.Logger) { if srv == nil { return } - var proto string - if srv.TLSConfig == nil { - proto = "http" - } else { - proto = "https" - } + proto := proto(srv) ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second) defer cancel() @@ -150,7 +162,7 @@ func Stop(srv *http.Server, logger *zerolog.Logger) { if err := srv.Shutdown(ctx); err != nil { HandleError(logger, err, "failed to shutdown "+proto+" server") } else { - logger.Info().Str("addr", srv.Addr).Msgf("server stopped") + logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped") } } diff --git a/internal/net/gphttp/server/utils.go b/internal/net/gphttp/server/utils.go new file mode 100644 index 0000000..968f584 --- /dev/null +++ b/internal/net/gphttp/server/utils.go @@ -0,0 +1,75 @@ +package server + +import ( + "log" + "log/slog" + "net/http" + + "github.com/quic-go/quic-go/http3" + "github.com/rs/zerolog" + slogzerolog "github.com/samber/slog-zerolog/v2" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/net/gphttp" +) + +func advertiseHTTP3(handler http.Handler, h3 *http3.Server) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor < 3 { + err := h3.SetQUICHeaders(w.Header()) + if err != nil { + gphttp.ServerError(w, r, err) + return + } + } + handler.ServeHTTP(w, r) + }) +} + +func proto[Server httpServer](srv Server) string { + var proto string + switch src := any(srv).(type) { + case *http.Server: + if src.TLSConfig == nil { + proto = "http" + } else { + proto = "https" + } + case *http3.Server: + proto = "h3" + } + return proto +} + +func addr[Server httpServer](srv Server) string { + var addr string + switch src := any(srv).(type) { + case *http.Server: + addr = src.Addr + case *http3.Server: + addr = src.Addr + } + return addr +} + +func getServeFunc[listener any](l listener, serve func(listener) error) func() error { + return func() error { + return serve(l) + } +} + +func setDebugLogger[Server httpServer](srv Server, logger *zerolog.Logger) { + if !common.IsDebug { + return + } + switch srv := any(srv).(type) { + case *http.Server: + srv.ErrorLog = log.New(logger, "", 0) + case *http3.Server: + logOpts := slogzerolog.Option{Level: slog.LevelDebug, Logger: logger} + srv.Logger = slog.New(logOpts.NewZerologHandler()) + } +} + +func logStarted[Server httpServer](srv Server, logger *zerolog.Logger) { + logger.Info().Str("proto", proto(srv)).Str("addr", addr(srv)).Msg("server started") +}