mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-02 05:14:25 +02:00
feat(idlesleep): support idlesleep for stream routes, rewritten and fixed stream implementation
This commit is contained in:
parent
25fbcc4ab9
commit
b5328fe5e7
16 changed files with 659 additions and 430 deletions
2
go.mod
2
go.mod
|
@ -222,7 +222,7 @@ require (
|
|||
go.opentelemetry.io/otel v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
go.uber.org/atomic v1.11.0
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
go.uber.org/mock v0.5.2 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
package idlewatcher
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
func (w *Watcher) cancelled(reqCtx context.Context) bool {
|
||||
func (w *Watcher) canceled(reqCtx context.Context) bool {
|
||||
select {
|
||||
case <-reqCtx.Done():
|
||||
w.l.Debug().AnErr("cause", context.Cause(reqCtx)).Msg("wake canceled")
|
||||
|
|
|
@ -92,7 +92,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
}
|
||||
|
||||
ctx := r.Context()
|
||||
if w.cancelled(ctx) {
|
||||
if w.canceled(ctx) {
|
||||
w.redirectToStartEndpoint(rw, r)
|
||||
return false
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
|
|||
for {
|
||||
w.resetIdleTimer()
|
||||
|
||||
if w.cancelled(ctx) {
|
||||
if w.canceled(ctx) {
|
||||
w.redirectToStartEndpoint(rw, r)
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -5,45 +5,51 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
gpnet "github.com/yusing/go-proxy/internal/net/types"
|
||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
// Setup implements types.Stream.
|
||||
func (w *Watcher) Addr() net.Addr {
|
||||
return w.stream.Addr()
|
||||
var _ nettypes.Stream = (*Watcher)(nil)
|
||||
|
||||
// ListenAndServe implements nettypes.Stream.
|
||||
func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) {
|
||||
w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck
|
||||
return w.preDial(ctx, predial)
|
||||
}, func(ctx context.Context) error {
|
||||
return w.onRead(ctx, onRead)
|
||||
})
|
||||
}
|
||||
|
||||
// Setup implements types.Stream.
|
||||
func (w *Watcher) Setup() error {
|
||||
return w.stream.Setup()
|
||||
}
|
||||
|
||||
// Accept implements types.Stream.
|
||||
func (w *Watcher) Accept() (conn gpnet.StreamConn, err error) {
|
||||
conn, err = w.stream.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if wakeErr := w.wakeFromStream(); wakeErr != nil {
|
||||
w.l.Err(wakeErr).Msg("error waking container")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle implements types.Stream.
|
||||
func (w *Watcher) Handle(conn gpnet.StreamConn) error {
|
||||
if err := w.wakeFromStream(); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.stream.Handle(conn)
|
||||
}
|
||||
|
||||
// Close implements types.Stream.
|
||||
// Close implements nettypes.Stream.
|
||||
func (w *Watcher) Close() error {
|
||||
return w.stream.Close()
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromStream() error {
|
||||
// LocalAddr implements nettypes.Stream.
|
||||
func (w *Watcher) LocalAddr() net.Addr {
|
||||
return w.stream.LocalAddr()
|
||||
}
|
||||
|
||||
func (w *Watcher) preDial(ctx context.Context, predial nettypes.HookFunc) error {
|
||||
if predial != nil {
|
||||
if err := predial(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return w.wakeFromStream(ctx)
|
||||
}
|
||||
|
||||
func (w *Watcher) onRead(ctx context.Context, onRead nettypes.HookFunc) error {
|
||||
w.resetIdleTimer()
|
||||
if onRead != nil {
|
||||
if err := onRead(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Watcher) wakeFromStream(ctx context.Context) error {
|
||||
w.resetIdleTimer()
|
||||
|
||||
// pass through if container is already ready
|
||||
|
@ -52,18 +58,27 @@ func (w *Watcher) wakeFromStream() error {
|
|||
}
|
||||
|
||||
w.l.Debug().Msg("wake signal received")
|
||||
err := w.Wake(context.Background())
|
||||
err := w.Wake(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
w.resetIdleTimer()
|
||||
|
||||
if w.canceled(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !w.waitStarted(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ready, err := w.checkUpdateState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ready {
|
||||
w.resetIdleTimer()
|
||||
w.l.Debug().Stringer("url", w.hc.URL()).Msg("container is ready, passing through")
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -261,7 +261,7 @@ func NewWatcher(parent task.Parent, r routes.Route, cfg *idlewatcher.Config) (*W
|
|||
case routes.ReverseProxyRoute:
|
||||
w.rp = r.ReverseProxy()
|
||||
case routes.StreamRoute:
|
||||
w.stream = r
|
||||
w.stream = r.Stream()
|
||||
default:
|
||||
w.provider.Close()
|
||||
return nil, w.newWatcherError(gperr.Errorf("unexpected route type: %T", r))
|
||||
|
|
|
@ -6,9 +6,9 @@ import (
|
|||
)
|
||||
|
||||
type Stream interface {
|
||||
ListenAndServe(ctx context.Context, preDial PreDialFunc)
|
||||
ListenAndServe(ctx context.Context, preDial, onRead HookFunc)
|
||||
LocalAddr() net.Addr
|
||||
Close() error
|
||||
}
|
||||
|
||||
type PreDialFunc func(ctx context.Context) error
|
||||
type HookFunc func(ctx context.Context) error
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/homepage"
|
||||
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
|
||||
netutils "github.com/yusing/go-proxy/internal/net"
|
||||
net "github.com/yusing/go-proxy/internal/net/types"
|
||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/proxmox"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
|
@ -39,7 +39,7 @@ type (
|
|||
Alias string `json:"alias"`
|
||||
Scheme route.Scheme `json:"scheme,omitempty"`
|
||||
Host string `json:"host,omitempty"`
|
||||
Port route.Port `json:"port,omitempty"`
|
||||
Port route.Port `json:"port"`
|
||||
Root string `json:"root,omitempty"`
|
||||
|
||||
route.HTTPConfig
|
||||
|
@ -64,8 +64,8 @@ type (
|
|||
Provider string `json:"provider,omitempty"` // for backward compatibility
|
||||
|
||||
// private fields
|
||||
LisURL *net.URL `json:"lurl,omitempty"`
|
||||
ProxyURL *net.URL `json:"purl,omitempty"`
|
||||
LisURL *nettypes.URL `json:"lurl,omitempty"`
|
||||
ProxyURL *nettypes.URL `json:"purl,omitempty"`
|
||||
|
||||
Excluded *bool `json:"excluded"`
|
||||
|
||||
|
@ -195,19 +195,19 @@ func (r *Route) Validate() gperr.Error {
|
|||
|
||||
switch r.Scheme {
|
||||
case route.SchemeFileServer:
|
||||
r.ProxyURL = gperr.Collect(errs, net.ParseURL, "file://"+r.Root)
|
||||
r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, "file://"+r.Root)
|
||||
r.Host = ""
|
||||
r.Port.Proxy = 0
|
||||
case route.SchemeHTTP, route.SchemeHTTPS:
|
||||
if r.Port.Listening != 0 {
|
||||
errs.Addf("unexpected listening port for %s scheme", r.Scheme)
|
||||
}
|
||||
r.ProxyURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||
r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||
case route.SchemeTCP, route.SchemeUDP:
|
||||
if !r.ShouldExclude() {
|
||||
r.LisURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://:%d", r.Scheme, r.Port.Listening))
|
||||
r.LisURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://:%d", r.Scheme, r.Port.Listening))
|
||||
}
|
||||
r.ProxyURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||
r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||
}
|
||||
|
||||
if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) {
|
||||
|
@ -309,7 +309,7 @@ func (r *Route) ProviderName() string {
|
|||
return r.Provider
|
||||
}
|
||||
|
||||
func (r *Route) TargetURL() *net.URL {
|
||||
func (r *Route) TargetURL() *nettypes.URL {
|
||||
return r.ProxyURL
|
||||
}
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ type (
|
|||
StreamRoute interface {
|
||||
Route
|
||||
nettypes.Stream
|
||||
Stream() nettypes.Stream
|
||||
}
|
||||
Provider interface {
|
||||
GetRoute(alias string) (r Route, ok bool)
|
||||
|
|
|
@ -2,7 +2,8 @@ package route
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/idlewatcher"
|
||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/route/routes"
|
||||
"github.com/yusing/go-proxy/internal/route/stream"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
|
||||
)
|
||||
|
@ -17,7 +19,7 @@ import (
|
|||
// TODO: support stream load balance.
|
||||
type StreamRoute struct {
|
||||
*Route
|
||||
nettypes.Stream `json:"-"`
|
||||
stream nettypes.Stream
|
||||
|
||||
l zerolog.Logger
|
||||
}
|
||||
|
@ -33,10 +35,19 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Stream() nettypes.Stream {
|
||||
return r.stream
|
||||
}
|
||||
|
||||
// Start implements task.TaskStarter.
|
||||
func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
||||
stream, err := r.initStream()
|
||||
if err != nil {
|
||||
return gperr.Wrap(err)
|
||||
}
|
||||
r.stream = stream
|
||||
|
||||
r.task = parent.Subtask("stream."+r.Name(), !r.ShouldExclude())
|
||||
r.Stream = NewStream(r)
|
||||
|
||||
switch {
|
||||
case r.UseIdleWatcher():
|
||||
|
@ -45,20 +56,12 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
|||
r.task.Finish(err)
|
||||
return gperr.Wrap(err, "idlewatcher error")
|
||||
}
|
||||
r.Stream = waker
|
||||
r.stream = waker
|
||||
r.HealthMon = waker
|
||||
case r.UseHealthCheck():
|
||||
r.HealthMon = monitor.NewMonitor(r)
|
||||
}
|
||||
|
||||
if !r.ShouldExclude() {
|
||||
if err := r.Setup(); err != nil {
|
||||
r.task.Finish(err)
|
||||
return gperr.Wrap(err)
|
||||
}
|
||||
r.l.Info().Int("port", r.Port.Listening).Msg("listening")
|
||||
}
|
||||
|
||||
if r.HealthMon != nil {
|
||||
if err := r.HealthMon.Start(r.task); err != nil {
|
||||
gperr.LogWarn("health monitor error", err, &r.l)
|
||||
|
@ -73,7 +76,14 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
|||
return err
|
||||
}
|
||||
|
||||
go r.acceptConnections()
|
||||
r.ListenAndServe(r.task.Context(), nil, nil)
|
||||
r.l = r.l.With().Stringer("rurl", r.ProxyURL).Stringer("laddr", r.LocalAddr()).Logger()
|
||||
r.l.Info().Msg("stream started")
|
||||
|
||||
r.task.OnCancel("close_stream", func() {
|
||||
r.stream.Close()
|
||||
r.l.Info().Msg("stream closed")
|
||||
})
|
||||
|
||||
routes.Stream.Add(r)
|
||||
r.task.OnCancel("remove_route_from_stream", func() {
|
||||
|
@ -82,38 +92,34 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *StreamRoute) acceptConnections() {
|
||||
defer r.task.Finish("listener closed")
|
||||
|
||||
go func() {
|
||||
<-r.task.Context().Done()
|
||||
r.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.task.Context().Done():
|
||||
return
|
||||
default:
|
||||
conn, err := r.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-r.task.Context().Done():
|
||||
default:
|
||||
gperr.LogError("accept connection error", err, &r.l)
|
||||
}
|
||||
r.task.Finish(err)
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
panic("connection is nil")
|
||||
}
|
||||
go func() {
|
||||
err := r.Handle(conn)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
gperr.LogError("handle connection error", err, &r.l)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
func (r *StreamRoute) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
r.stream.ListenAndServe(ctx, preDial, onRead)
|
||||
}
|
||||
|
||||
func (r *StreamRoute) Close() error {
|
||||
return r.stream.Close()
|
||||
}
|
||||
|
||||
func (r *StreamRoute) LocalAddr() net.Addr {
|
||||
return r.stream.LocalAddr()
|
||||
}
|
||||
|
||||
func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
||||
lurl, rurl := r.LisURL, r.ProxyURL
|
||||
if lurl != nil && lurl.Scheme != rurl.Scheme {
|
||||
return nil, fmt.Errorf("incoherent scheme is not yet supported: %s != %s", lurl.Scheme, rurl.Scheme)
|
||||
}
|
||||
|
||||
laddr := ":0"
|
||||
if lurl != nil {
|
||||
laddr = lurl.Host
|
||||
}
|
||||
|
||||
switch rurl.Scheme {
|
||||
case "tcp":
|
||||
return stream.NewTCPTCPStream(laddr, rurl.Host)
|
||||
case "udp":
|
||||
return stream.NewUDPUDPStream(laddr, rurl.Host)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
||||
}
|
||||
|
|
12
internal/route/stream/debug_debug.go
Normal file
12
internal/route/stream/debug_debug.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
//go:build debug
|
||||
|
||||
package stream
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func logDebugf(stream zerolog.LogObjectMarshaler, format string, v ...any) {
|
||||
log.Debug().Object("stream", stream).Msgf(format, v...)
|
||||
}
|
7
internal/route/stream/debug_prod.go
Normal file
7
internal/route/stream/debug_prod.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
//go:build !debug
|
||||
|
||||
package stream
|
||||
|
||||
import "github.com/rs/zerolog"
|
||||
|
||||
func logDebugf(stream zerolog.LogObjectMarshaler, format string, v ...any) {}
|
41
internal/route/stream/errors.go
Normal file
41
internal/route/stream/errors.go
Normal file
|
@ -0,0 +1,41 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func convertErr(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled),
|
||||
errors.Is(err, io.ErrClosedPipe),
|
||||
errors.Is(err, syscall.ECONNRESET):
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func logErr(stream zerolog.LogObjectMarshaler, err error, msg string) {
|
||||
err = convertErr(err)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Err(err).Object("stream", stream).Msg(msg)
|
||||
}
|
||||
|
||||
func logErrf(stream zerolog.LogObjectMarshaler, err error, format string, v ...any) {
|
||||
err = convertErr(err)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Err(err).Object("stream", stream).Msgf(format, v...)
|
||||
}
|
162
internal/route/stream/tcp_tcp.go
Normal file
162
internal/route/stream/tcp_tcp.go
Normal file
|
@ -0,0 +1,162 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type TCPTCPStream struct {
|
||||
listener *net.TCPListener
|
||||
laddr *net.TCPAddr
|
||||
dst *net.TCPAddr
|
||||
|
||||
preDial nettypes.HookFunc
|
||||
onRead nettypes.HookFunc
|
||||
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveTCPAddr("tcp", dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
laddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPTCPStream{laddr: laddr, dst: dst}, nil
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
listener, err := net.ListenTCP("tcp", s.laddr)
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to listen")
|
||||
return
|
||||
}
|
||||
s.listener = listener
|
||||
s.preDial = preDial
|
||||
s.onRead = onRead
|
||||
go s.listen(ctx)
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) Close() error {
|
||||
if s.closed.Swap(true) || s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) LocalAddr() net.Addr {
|
||||
if s.listener == nil {
|
||||
return s.laddr
|
||||
}
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Str("protocol", "tcp-tcp").Str("listen", s.listener.Addr().String()).Str("dst", s.dst.String())
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) listen(ctx context.Context) {
|
||||
for {
|
||||
if s.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if s.closed.Load() {
|
||||
return
|
||||
}
|
||||
logErr(s, err, "failed to accept connection")
|
||||
continue
|
||||
}
|
||||
if s.onRead != nil {
|
||||
if err := s.onRead(ctx); err != nil {
|
||||
logErr(s, err, "failed to on read")
|
||||
continue
|
||||
}
|
||||
}
|
||||
go s.handle(ctx, conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
if s.preDial != nil {
|
||||
if err := s.preDial(ctx); err != nil {
|
||||
if !s.closed.Load() {
|
||||
logErr(s, err, "failed to pre-dial")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if s.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
dstConn, err := net.DialTCP("tcp", nil, s.dst)
|
||||
if err != nil {
|
||||
if !s.closed.Load() {
|
||||
logErr(s, err, "failed to dial destination")
|
||||
}
|
||||
return
|
||||
}
|
||||
defer dstConn.Close()
|
||||
|
||||
if s.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
src := conn
|
||||
dst := net.Conn(dstConn)
|
||||
if s.onRead != nil {
|
||||
src = &wrapperConn{
|
||||
Conn: conn,
|
||||
ctx: ctx,
|
||||
onRead: s.onRead,
|
||||
}
|
||||
dst = &wrapperConn{
|
||||
Conn: dstConn,
|
||||
ctx: ctx,
|
||||
onRead: s.onRead,
|
||||
}
|
||||
}
|
||||
|
||||
pipe := utils.NewBidirectionalPipe(ctx, src, dst)
|
||||
if err := pipe.Start(); err != nil && !s.closed.Load() {
|
||||
logErr(s, err, "error in bidirectional pipe")
|
||||
}
|
||||
}
|
||||
|
||||
type wrapperConn struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
onRead nettypes.HookFunc
|
||||
}
|
||||
|
||||
func (w *wrapperConn) Read(b []byte) (n int, err error) {
|
||||
n, err = w.Conn.Read(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if w.onRead != nil {
|
||||
if err = w.onRead(w.ctx); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
316
internal/route/stream/udp_udp.go
Normal file
316
internal/route/stream/udp_udp.go
Normal file
|
@ -0,0 +1,316 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"maps"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/synk"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type UDPUDPStream struct {
|
||||
name string
|
||||
listener *net.UDPConn
|
||||
|
||||
laddr *net.UDPAddr
|
||||
dst *net.UDPAddr
|
||||
|
||||
preDial nettypes.HookFunc
|
||||
onRead nettypes.HookFunc
|
||||
|
||||
cleanUpTicker *time.Ticker
|
||||
|
||||
conns map[string]*udpUDPConn
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type udpUDPConn struct {
|
||||
srcAddr *net.UDPAddr
|
||||
dstConn *net.UDPConn
|
||||
listener *net.UDPConn
|
||||
lastUsed atomic.Time
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
const (
|
||||
udpBufferSize = 16 * 1024
|
||||
udpIdleTimeout = 5 * time.Minute // Longer timeout for game sessions
|
||||
udpCleanupInterval = 1 * time.Minute
|
||||
udpReadTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var bufPool = synk.NewBytesPool()
|
||||
|
||||
func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveUDPAddr("udp", dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
laddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &UDPUDPStream{
|
||||
laddr: laddr,
|
||||
dst: dst,
|
||||
conns: make(map[string]*udpUDPConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
listener, err := net.ListenUDP("udp", s.laddr)
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to listen")
|
||||
return
|
||||
}
|
||||
s.listener = listener
|
||||
s.preDial = preDial
|
||||
s.onRead = onRead
|
||||
go s.listen(ctx)
|
||||
go s.cleanUp(ctx)
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) Close() error {
|
||||
if s.closed.Swap(true) || s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
s.mu.Lock()
|
||||
for _, conn := range s.conns {
|
||||
wg.Add(1)
|
||||
go func(c *udpUDPConn) {
|
||||
defer wg.Done()
|
||||
c.Close()
|
||||
}(conn)
|
||||
}
|
||||
clear(s.conns)
|
||||
s.mu.Unlock()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) LocalAddr() net.Addr {
|
||||
if s.listener == nil {
|
||||
return s.laddr
|
||||
}
|
||||
return s.listener.LocalAddr()
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Str("protocol", "udp-udp").Str("name", s.name).Str("dst", s.dst.String())
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) listen(ctx context.Context) {
|
||||
buf := bufPool.GetSized(udpBufferSize)
|
||||
defer bufPool.Put(buf)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, srcAddr, err := s.listener.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
if s.closed.Load() {
|
||||
return
|
||||
}
|
||||
logErr(s, err, "failed to read from listener")
|
||||
continue
|
||||
}
|
||||
|
||||
logDebugf(s, "read %d bytes from %s", n, srcAddr)
|
||||
|
||||
if s.onRead != nil {
|
||||
if err := s.onRead(ctx); err != nil {
|
||||
logErr(s, err, "failed to on read")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get or create connection, passing the initial data
|
||||
go s.getOrCreateConnection(ctx, srcAddr, bytes.Clone(buf[:n]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) getOrCreateConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) {
|
||||
key := srcAddr.String()
|
||||
|
||||
s.mu.Lock()
|
||||
if conn, ok := s.conns[key]; ok {
|
||||
s.mu.Unlock()
|
||||
// Forward packet for existing connection
|
||||
go conn.forwardToDestination(initialData)
|
||||
return
|
||||
}
|
||||
|
||||
defer s.mu.Unlock()
|
||||
// Create new connection with initial data
|
||||
conn, ok := s.createConnection(ctx, srcAddr, initialData)
|
||||
if ok && !conn.closed.Load() {
|
||||
s.conns[key] = conn
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) (*udpUDPConn, bool) {
|
||||
// Apply pre-dial if configured
|
||||
if s.preDial != nil {
|
||||
if err := s.preDial(ctx); err != nil {
|
||||
logErr(s, err, "failed to pre-dial")
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Create UDP connection to destination
|
||||
dstConn, err := net.DialUDP("udp", nil, s.dst)
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to dial dst")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
conn := &udpUDPConn{
|
||||
srcAddr: srcAddr,
|
||||
dstConn: dstConn,
|
||||
listener: s.listener,
|
||||
}
|
||||
conn.lastUsed.Store(time.Now())
|
||||
|
||||
// Send initial data before starting response handler
|
||||
if !conn.forwardToDestination(initialData) {
|
||||
dstConn.Close()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Start response handler after initial data is sent
|
||||
go conn.handleResponses(ctx)
|
||||
|
||||
logDebugf(s, "created new connection from %s", srcAddr.String())
|
||||
return conn, true
|
||||
}
|
||||
|
||||
func (conn *udpUDPConn) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Stringer("src", conn.srcAddr).Stringer("dst", conn.dstConn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (conn *udpUDPConn) handleResponses(ctx context.Context) {
|
||||
buf := bufPool.GetSized(udpBufferSize)
|
||||
defer bufPool.Put(buf)
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
if conn.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// Set a reasonable timeout for reads
|
||||
_ = conn.dstConn.SetReadDeadline(time.Now().Add(udpReadTimeout))
|
||||
|
||||
n, err := conn.dstConn.Read(buf)
|
||||
if err != nil {
|
||||
if !conn.closed.Load() {
|
||||
logErr(conn, err, "failed to read from dst")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Clear deadline after successful read
|
||||
_ = conn.dstConn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Forward response back to client using the listener
|
||||
_, err = conn.listener.WriteToUDP(buf[:n], conn.srcAddr)
|
||||
if err != nil {
|
||||
if !conn.closed.Load() {
|
||||
logErrf(conn, err, "failed to write %d bytes to client", n)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn.lastUsed.Store(time.Now())
|
||||
logDebugf(conn, "forwarded response to client, %d bytes", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) cleanUp(ctx context.Context) {
|
||||
s.cleanUpTicker = time.NewTicker(udpCleanupInterval)
|
||||
defer s.cleanUpTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.cleanUpTicker.C:
|
||||
s.mu.Lock()
|
||||
conns := maps.Clone(s.conns)
|
||||
s.mu.Unlock()
|
||||
|
||||
removed := []string(nil)
|
||||
for key, conn := range conns {
|
||||
if conn.Expired() {
|
||||
conn.Close()
|
||||
removed = append(removed, key)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
for _, key := range removed {
|
||||
logDebugf(s, "cleaning up expired connection: %s", key)
|
||||
delete(s.conns, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *udpUDPConn) forwardToDestination(data []byte) bool {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.closed.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := conn.dstConn.Write(data)
|
||||
if err != nil {
|
||||
logErrf(conn, err, "failed to write %d bytes to dst", len(data))
|
||||
return false
|
||||
}
|
||||
|
||||
conn.lastUsed.Store(time.Now())
|
||||
logDebugf(conn, "forwarded %d bytes to dst", len(data))
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *udpUDPConn) Expired() bool {
|
||||
return time.Since(conn.lastUsed.Load()) > udpIdleTimeout
|
||||
}
|
||||
|
||||
func (conn *udpUDPConn) Close() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
conn.closed.Store(true)
|
||||
|
||||
conn.dstConn.Close()
|
||||
conn.dstConn = nil
|
||||
}
|
|
@ -1,129 +0,0 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type (
|
||||
Stream struct {
|
||||
*StreamRoute
|
||||
|
||||
listener types.StreamListener
|
||||
targetAddr net.Addr
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
streamFirstConnBufferSize = 128
|
||||
streamDialTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func NewStream(base *StreamRoute) *Stream {
|
||||
return &Stream{
|
||||
StreamRoute: base,
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *Stream) Addr() net.Addr {
|
||||
if stream.listener == nil {
|
||||
panic("listener is nil")
|
||||
}
|
||||
return stream.listener.Addr()
|
||||
}
|
||||
|
||||
func (stream *Stream) Setup() error {
|
||||
var lcfg net.ListenConfig
|
||||
var err error
|
||||
|
||||
ctx := stream.task.Context()
|
||||
|
||||
switch stream.Scheme {
|
||||
case "tcp":
|
||||
stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.ProxyURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tcpListener, err := lcfg.Listen(ctx, "tcp", stream.LisURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// in case ListeningPort was zero, get the actual port
|
||||
stream.Port.Listening = tcpListener.Addr().(*net.TCPAddr).Port
|
||||
stream.listener = types.NetListener(tcpListener)
|
||||
case "udp":
|
||||
stream.targetAddr, err = net.ResolveUDPAddr("udp", stream.ProxyURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.LisURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpConn, ok := udpListener.(*net.UDPConn)
|
||||
if !ok {
|
||||
udpListener.Close()
|
||||
return errors.New("udp listener is not *net.UDPConn")
|
||||
}
|
||||
stream.Port.Listening = udpConn.LocalAddr().(*net.UDPAddr).Port
|
||||
stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr)
|
||||
default:
|
||||
panic("should not reach here")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stream *Stream) Accept() (conn types.StreamConn, err error) {
|
||||
if stream.listener == nil {
|
||||
return nil, errors.New("listener is nil")
|
||||
}
|
||||
// prevent Accept from blocking forever
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
conn, err = stream.listener.Accept()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stream.task.Context().Done():
|
||||
stream.Close()
|
||||
return nil, stream.task.Context().Err()
|
||||
case <-done:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *Stream) Handle(conn types.StreamConn) error {
|
||||
switch conn := conn.(type) {
|
||||
case *UDPConn:
|
||||
switch stream := stream.listener.(type) {
|
||||
case *UDPForwarder:
|
||||
return stream.Handle(conn)
|
||||
default:
|
||||
return fmt.Errorf("unexpected listener type: %T", stream)
|
||||
}
|
||||
case io.ReadWriteCloser:
|
||||
dialer := &net.Dialer{Timeout: streamDialTimeout}
|
||||
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dstConn.Close()
|
||||
defer conn.Close()
|
||||
pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn)
|
||||
return pipe.Start()
|
||||
default:
|
||||
return fmt.Errorf("unexpected conn type: %T", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *Stream) Close() error {
|
||||
return stream.listener.Close()
|
||||
}
|
|
@ -1,204 +0,0 @@
|
|||
package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
UDPForwarder struct {
|
||||
ctx context.Context
|
||||
forwarder *net.UDPConn
|
||||
dstAddr net.Addr
|
||||
connMap F.Map[string, *UDPConn]
|
||||
mu sync.Mutex
|
||||
}
|
||||
UDPConn struct {
|
||||
srcAddr *net.UDPAddr
|
||||
conn net.Conn
|
||||
buf *UDPBuf
|
||||
}
|
||||
UDPBuf struct {
|
||||
data, oob []byte
|
||||
n, oobn int
|
||||
}
|
||||
)
|
||||
|
||||
const udpConnBufferSize = 4096
|
||||
|
||||
func NewUDPForwarder(ctx context.Context, forwarder *net.UDPConn, dstAddr net.Addr) *UDPForwarder {
|
||||
return &UDPForwarder{
|
||||
ctx: ctx,
|
||||
forwarder: forwarder,
|
||||
dstAddr: dstAddr,
|
||||
connMap: F.NewMapOf[string, *UDPConn](),
|
||||
}
|
||||
}
|
||||
|
||||
func newUDPBuf() *UDPBuf {
|
||||
return &UDPBuf{
|
||||
data: make([]byte, udpConnBufferSize),
|
||||
oob: make([]byte, udpConnBufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *UDPConn) DstAddrString() string {
|
||||
return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) Addr() net.Addr {
|
||||
return w.forwarder.LocalAddr()
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) Accept() (types.StreamConn, error) {
|
||||
buf := newUDPBuf()
|
||||
addr, err := w.readFromListener(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &UDPConn{
|
||||
srcAddr: addr,
|
||||
buf: buf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) {
|
||||
switch dstAddr := w.dstAddr.(type) {
|
||||
case *net.UDPAddr:
|
||||
var laddr *net.UDPAddr
|
||||
if dstAddr.IP.IsLoopback() {
|
||||
laddr, _ = net.ResolveUDPAddr(dstAddr.Network(), "127.0.0.1:")
|
||||
}
|
||||
dstConn, err = net.DialUDP(w.dstAddr.Network(), laddr, dstAddr)
|
||||
case *net.TCPAddr:
|
||||
dstConn, err = net.DialTCP(w.dstAddr.Network(), nil, dstAddr)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported network %s", w.dstAddr.Network())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) {
|
||||
buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob)
|
||||
if err == nil {
|
||||
log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (conn *UDPConn) read() (err error) {
|
||||
switch dstConn := conn.conn.(type) {
|
||||
case *net.UDPConn:
|
||||
conn.buf.n, conn.buf.oobn, _, _, err = dstConn.ReadMsgUDP(conn.buf.data, conn.buf.oob)
|
||||
default:
|
||||
conn.buf.n, err = dstConn.Read(conn.buf.data[:conn.buf.n])
|
||||
conn.buf.oobn = 0
|
||||
}
|
||||
if err == nil {
|
||||
log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) {
|
||||
buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr)
|
||||
if err == nil {
|
||||
log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (conn *UDPConn) write() (err error) {
|
||||
switch dstConn := conn.conn.(type) {
|
||||
case *net.UDPConn:
|
||||
conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil)
|
||||
if err == nil {
|
||||
log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
|
||||
}
|
||||
default:
|
||||
_, err = dstConn.Write(conn.buf.data[:conn.buf.n])
|
||||
if err == nil {
|
||||
log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
dst, ok := w.connMap.Load(key)
|
||||
if !ok {
|
||||
var err error
|
||||
dst = conn
|
||||
dst.conn, err = w.dialDst()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := dst.write(); err != nil {
|
||||
dst.conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
w.connMap.Store(key, dst)
|
||||
} else {
|
||||
conn.conn = dst.conn
|
||||
if err := conn.write(); err != nil {
|
||||
w.connMap.Delete(key)
|
||||
dst.conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
|
||||
conn, ok := streamConn.(*UDPConn)
|
||||
if !ok {
|
||||
panic("unexpected conn type")
|
||||
}
|
||||
|
||||
key := conn.srcAddr.String()
|
||||
dst, err := w.getInitConn(conn, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
if err := dst.read(); err != nil {
|
||||
w.connMap.Delete(key)
|
||||
dst.conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.writeToSrc(dst.srcAddr, dst.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) Close() error {
|
||||
errs := gperr.NewBuilder("errors closing udp conn")
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.connMap.RangeAll(func(key string, conn *UDPConn) {
|
||||
errs.Add(conn.conn.Close())
|
||||
})
|
||||
w.connMap.Clear()
|
||||
errs.Add(w.forwarder.Close())
|
||||
return errs.Error()
|
||||
}
|
Loading…
Add table
Reference in a new issue