feat(http3): add HTTP/3 support and refactor server handling code into utility functions

This commit is contained in:
yusing 2025-04-02 15:34:13 +08:00
parent c15c3b0fa1
commit ac4b7e9490
3 changed files with 130 additions and 41 deletions

View file

@ -19,6 +19,8 @@ var (
IsDebug = GetEnvBool("DEBUG", IsTest) IsDebug = GetEnvBool("DEBUG", IsTest)
IsTrace = GetEnvBool("TRACE", false) && IsDebug IsTrace = GetEnvBool("TRACE", false) && IsDebug
HTTP3Enabled = GetEnvBool("HTTP3_ENABLED", true)
ProxyHTTPAddr, ProxyHTTPAddr,
ProxyHTTPHost, ProxyHTTPHost,
ProxyHTTPPort, ProxyHTTPPort,

View file

@ -3,11 +3,11 @@ package server
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"log"
"net" "net"
"net/http" "net/http"
"time" "time"
"github.com/quic-go/quic-go/http3"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
@ -33,6 +33,11 @@ type Options struct {
Handler http.Handler Handler http.Handler
} }
type httpServer interface {
*http.Server | *http3.Server
Shutdown(ctx context.Context) error
}
func StartServer(parent task.Parent, opt Options) (s *Server) { func StartServer(parent task.Parent, opt Options) (s *Server) {
s = NewServer(opt) s = NewServer(opt)
s.Start(parent) s.Start(parent)
@ -82,67 +87,74 @@ func NewServer(opt Options) (s *Server) {
func (s *Server) Start(parent task.Parent) { func (s *Server) Start(parent task.Parent) {
s.startTime = time.Now() s.startTime = time.Now()
subtask := parent.Subtask("server."+s.Name, false) subtask := parent.Subtask("server."+s.Name, false)
if s.https != nil && common.HTTP3Enabled {
s.https.TLSConfig.NextProtos = []string{http3.NextProtoH3, "h2", "http/1.1"}
h3 := &http3.Server{
Addr: s.https.Addr,
Handler: s.https.Handler,
TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig),
}
Start(subtask, h3, &s.l)
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
}
Start(subtask, s.http, &s.l) Start(subtask, s.http, &s.l)
Start(subtask, s.https, &s.l) Start(subtask, s.https, &s.l)
} }
func Start(parent task.Parent, srv *http.Server, logger *zerolog.Logger) { func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Logger) {
if srv == nil { if srv == nil {
return return
} }
srv.BaseContext = func(l net.Listener) context.Context {
return parent.Context()
}
if common.IsDebug { setDebugLogger(srv, logger)
srv.ErrorLog = log.New(logger, "", 0)
}
var proto string
if srv.TLSConfig == nil {
proto = "http"
} else {
proto = "https"
}
proto := proto(srv)
task := parent.Subtask(proto, false) task := parent.Subtask(proto, false)
var lc net.ListenConfig var lc net.ListenConfig
var serveFunc func() error
// Serve already closes the listener on return switch srv := any(srv).(type) {
l, err := lc.Listen(task.Context(), "tcp", srv.Addr) case *http.Server:
if err != nil { srv.BaseContext = func(l net.Listener) context.Context {
HandleError(logger, err, "failed to listen on port") return parent.Context()
return
}
task.OnCancel("stop", func() {
Stop(srv, logger)
})
logger.Info().Str("addr", srv.Addr).Msg("server started")
go func() {
if srv.TLSConfig == nil {
err = srv.Serve(l)
} else {
err = srv.Serve(tls.NewListener(l, srv.TLSConfig))
} }
l, err := lc.Listen(task.Context(), "tcp", srv.Addr)
if err != nil {
HandleError(logger, err, "failed to listen on port")
return
}
if srv.TLSConfig != nil {
l = tls.NewListener(l, srv.TLSConfig)
}
serveFunc = getServeFunc(l, srv.Serve)
case *http3.Server:
l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr)
if err != nil {
HandleError(logger, err, "failed to listen on port")
return
}
serveFunc = getServeFunc(l, srv.Serve)
}
task.OnCancel("stop", func() {
stop(srv, logger)
})
logStarted(srv, logger)
go func() {
err := serveFunc()
HandleError(logger, err, "failed to serve "+proto+" server") HandleError(logger, err, "failed to serve "+proto+" server")
}() }()
} }
func Stop(srv *http.Server, logger *zerolog.Logger) { func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
if srv == nil { if srv == nil {
return return
} }
var proto string proto := proto(srv)
if srv.TLSConfig == nil {
proto = "http"
} else {
proto = "https"
}
ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second) ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second)
defer cancel() defer cancel()
@ -150,7 +162,7 @@ func Stop(srv *http.Server, logger *zerolog.Logger) {
if err := srv.Shutdown(ctx); err != nil { if err := srv.Shutdown(ctx); err != nil {
HandleError(logger, err, "failed to shutdown "+proto+" server") HandleError(logger, err, "failed to shutdown "+proto+" server")
} else { } else {
logger.Info().Str("addr", srv.Addr).Msgf("server stopped") logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped")
} }
} }

View file

@ -0,0 +1,75 @@
package server
import (
"log"
"log/slog"
"net/http"
"github.com/quic-go/quic-go/http3"
"github.com/rs/zerolog"
slogzerolog "github.com/samber/slog-zerolog/v2"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/net/gphttp"
)
func advertiseHTTP3(handler http.Handler, h3 *http3.Server) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ProtoMajor < 3 {
err := h3.SetQUICHeaders(w.Header())
if err != nil {
gphttp.ServerError(w, r, err)
return
}
}
handler.ServeHTTP(w, r)
})
}
func proto[Server httpServer](srv Server) string {
var proto string
switch src := any(srv).(type) {
case *http.Server:
if src.TLSConfig == nil {
proto = "http"
} else {
proto = "https"
}
case *http3.Server:
proto = "h3"
}
return proto
}
func addr[Server httpServer](srv Server) string {
var addr string
switch src := any(srv).(type) {
case *http.Server:
addr = src.Addr
case *http3.Server:
addr = src.Addr
}
return addr
}
func getServeFunc[listener any](l listener, serve func(listener) error) func() error {
return func() error {
return serve(l)
}
}
func setDebugLogger[Server httpServer](srv Server, logger *zerolog.Logger) {
if !common.IsDebug {
return
}
switch srv := any(srv).(type) {
case *http.Server:
srv.ErrorLog = log.New(logger, "", 0)
case *http3.Server:
logOpts := slogzerolog.Option{Level: slog.LevelDebug, Logger: logger}
srv.Logger = slog.New(logOpts.NewZerologHandler())
}
}
func logStarted[Server httpServer](srv Server, logger *zerolog.Logger) {
logger.Info().Str("proto", proto(srv)).Str("addr", addr(srv)).Msg("server started")
}