mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 12:42:34 +02:00
feat(http3): add HTTP/3 support and refactor server handling code into utility functions
This commit is contained in:
parent
c15c3b0fa1
commit
ac4b7e9490
3 changed files with 130 additions and 41 deletions
|
@ -19,6 +19,8 @@ var (
|
|||
IsDebug = GetEnvBool("DEBUG", IsTest)
|
||||
IsTrace = GetEnvBool("TRACE", false) && IsDebug
|
||||
|
||||
HTTP3Enabled = GetEnvBool("HTTP3_ENABLED", true)
|
||||
|
||||
ProxyHTTPAddr,
|
||||
ProxyHTTPHost,
|
||||
ProxyHTTPPort,
|
||||
|
|
|
@ -3,11 +3,11 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/go-proxy/internal/autocert"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
|
@ -33,6 +33,11 @@ type Options struct {
|
|||
Handler http.Handler
|
||||
}
|
||||
|
||||
type httpServer interface {
|
||||
*http.Server | *http3.Server
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
func StartServer(parent task.Parent, opt Options) (s *Server) {
|
||||
s = NewServer(opt)
|
||||
s.Start(parent)
|
||||
|
@ -82,67 +87,74 @@ func NewServer(opt Options) (s *Server) {
|
|||
func (s *Server) Start(parent task.Parent) {
|
||||
s.startTime = time.Now()
|
||||
subtask := parent.Subtask("server."+s.Name, false)
|
||||
|
||||
if s.https != nil && common.HTTP3Enabled {
|
||||
s.https.TLSConfig.NextProtos = []string{http3.NextProtoH3, "h2", "http/1.1"}
|
||||
h3 := &http3.Server{
|
||||
Addr: s.https.Addr,
|
||||
Handler: s.https.Handler,
|
||||
TLSConfig: http3.ConfigureTLSConfig(s.https.TLSConfig),
|
||||
}
|
||||
Start(subtask, h3, &s.l)
|
||||
s.http.Handler = advertiseHTTP3(s.http.Handler, h3)
|
||||
s.https.Handler = advertiseHTTP3(s.https.Handler, h3)
|
||||
}
|
||||
|
||||
Start(subtask, s.http, &s.l)
|
||||
Start(subtask, s.https, &s.l)
|
||||
}
|
||||
|
||||
func Start(parent task.Parent, srv *http.Server, logger *zerolog.Logger) {
|
||||
func Start[Server httpServer](parent task.Parent, srv Server, logger *zerolog.Logger) {
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
srv.BaseContext = func(l net.Listener) context.Context {
|
||||
return parent.Context()
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
srv.ErrorLog = log.New(logger, "", 0)
|
||||
}
|
||||
|
||||
var proto string
|
||||
if srv.TLSConfig == nil {
|
||||
proto = "http"
|
||||
} else {
|
||||
proto = "https"
|
||||
}
|
||||
setDebugLogger(srv, logger)
|
||||
|
||||
proto := proto(srv)
|
||||
task := parent.Subtask(proto, false)
|
||||
|
||||
var lc net.ListenConfig
|
||||
var serveFunc func() error
|
||||
|
||||
// Serve already closes the listener on return
|
||||
l, err := lc.Listen(task.Context(), "tcp", srv.Addr)
|
||||
if err != nil {
|
||||
HandleError(logger, err, "failed to listen on port")
|
||||
return
|
||||
}
|
||||
|
||||
task.OnCancel("stop", func() {
|
||||
Stop(srv, logger)
|
||||
})
|
||||
|
||||
logger.Info().Str("addr", srv.Addr).Msg("server started")
|
||||
|
||||
go func() {
|
||||
if srv.TLSConfig == nil {
|
||||
err = srv.Serve(l)
|
||||
} else {
|
||||
err = srv.Serve(tls.NewListener(l, srv.TLSConfig))
|
||||
switch srv := any(srv).(type) {
|
||||
case *http.Server:
|
||||
srv.BaseContext = func(l net.Listener) context.Context {
|
||||
return parent.Context()
|
||||
}
|
||||
l, err := lc.Listen(task.Context(), "tcp", srv.Addr)
|
||||
if err != nil {
|
||||
HandleError(logger, err, "failed to listen on port")
|
||||
return
|
||||
}
|
||||
if srv.TLSConfig != nil {
|
||||
l = tls.NewListener(l, srv.TLSConfig)
|
||||
}
|
||||
serveFunc = getServeFunc(l, srv.Serve)
|
||||
case *http3.Server:
|
||||
l, err := lc.ListenPacket(task.Context(), "udp", srv.Addr)
|
||||
if err != nil {
|
||||
HandleError(logger, err, "failed to listen on port")
|
||||
return
|
||||
}
|
||||
serveFunc = getServeFunc(l, srv.Serve)
|
||||
}
|
||||
task.OnCancel("stop", func() {
|
||||
stop(srv, logger)
|
||||
})
|
||||
logStarted(srv, logger)
|
||||
go func() {
|
||||
err := serveFunc()
|
||||
HandleError(logger, err, "failed to serve "+proto+" server")
|
||||
}()
|
||||
}
|
||||
|
||||
func Stop(srv *http.Server, logger *zerolog.Logger) {
|
||||
func stop[Server httpServer](srv Server, logger *zerolog.Logger) {
|
||||
if srv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var proto string
|
||||
if srv.TLSConfig == nil {
|
||||
proto = "http"
|
||||
} else {
|
||||
proto = "https"
|
||||
}
|
||||
proto := proto(srv)
|
||||
|
||||
ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
@ -150,7 +162,7 @@ func Stop(srv *http.Server, logger *zerolog.Logger) {
|
|||
if err := srv.Shutdown(ctx); err != nil {
|
||||
HandleError(logger, err, "failed to shutdown "+proto+" server")
|
||||
} else {
|
||||
logger.Info().Str("addr", srv.Addr).Msgf("server stopped")
|
||||
logger.Info().Str("proto", proto).Str("addr", addr(srv)).Msg("server stopped")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
75
internal/net/gphttp/server/utils.go
Normal file
75
internal/net/gphttp/server/utils.go
Normal file
|
@ -0,0 +1,75 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog"
|
||||
slogzerolog "github.com/samber/slog-zerolog/v2"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
)
|
||||
|
||||
func advertiseHTTP3(handler http.Handler, h3 *http3.Server) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ProtoMajor < 3 {
|
||||
err := h3.SetQUICHeaders(w.Header())
|
||||
if err != nil {
|
||||
gphttp.ServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func proto[Server httpServer](srv Server) string {
|
||||
var proto string
|
||||
switch src := any(srv).(type) {
|
||||
case *http.Server:
|
||||
if src.TLSConfig == nil {
|
||||
proto = "http"
|
||||
} else {
|
||||
proto = "https"
|
||||
}
|
||||
case *http3.Server:
|
||||
proto = "h3"
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
func addr[Server httpServer](srv Server) string {
|
||||
var addr string
|
||||
switch src := any(srv).(type) {
|
||||
case *http.Server:
|
||||
addr = src.Addr
|
||||
case *http3.Server:
|
||||
addr = src.Addr
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func getServeFunc[listener any](l listener, serve func(listener) error) func() error {
|
||||
return func() error {
|
||||
return serve(l)
|
||||
}
|
||||
}
|
||||
|
||||
func setDebugLogger[Server httpServer](srv Server, logger *zerolog.Logger) {
|
||||
if !common.IsDebug {
|
||||
return
|
||||
}
|
||||
switch srv := any(srv).(type) {
|
||||
case *http.Server:
|
||||
srv.ErrorLog = log.New(logger, "", 0)
|
||||
case *http3.Server:
|
||||
logOpts := slogzerolog.Option{Level: slog.LevelDebug, Logger: logger}
|
||||
srv.Logger = slog.New(logOpts.NewZerologHandler())
|
||||
}
|
||||
}
|
||||
|
||||
func logStarted[Server httpServer](srv Server, logger *zerolog.Logger) {
|
||||
logger.Info().Str("proto", proto(srv)).Str("addr", addr(srv)).Msg("server started")
|
||||
}
|
Loading…
Add table
Reference in a new issue