feat(idlesleep): support idlesleep for stream routes, rewritten and fixed stream implementation

This commit is contained in:
yusing 2025-06-09 22:20:26 +08:00
parent 25fbcc4ab9
commit b5328fe5e7
16 changed files with 659 additions and 430 deletions

2
go.mod
View file

@ -222,7 +222,7 @@ require (
go.opentelemetry.io/otel v1.36.0 // indirect
go.opentelemetry.io/otel/metric v1.36.0 // indirect
go.opentelemetry.io/otel/trace v1.36.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/atomic v1.11.0
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/mock v0.5.2 // indirect
go.uber.org/multierr v1.11.0 // indirect

View file

@ -1,8 +1,10 @@
package idlewatcher
import "context"
import (
"context"
)
func (w *Watcher) cancelled(reqCtx context.Context) bool {
func (w *Watcher) canceled(reqCtx context.Context) bool {
select {
case <-reqCtx.Done():
w.l.Debug().AnErr("cause", context.Cause(reqCtx)).Msg("wake canceled")

View file

@ -92,7 +92,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
}
ctx := r.Context()
if w.cancelled(ctx) {
if w.canceled(ctx) {
w.redirectToStartEndpoint(rw, r)
return false
}
@ -107,7 +107,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN
for {
w.resetIdleTimer()
if w.cancelled(ctx) {
if w.canceled(ctx) {
w.redirectToStartEndpoint(rw, r)
return false
}

View file

@ -5,45 +5,51 @@ import (
"net"
"time"
gpnet "github.com/yusing/go-proxy/internal/net/types"
nettypes "github.com/yusing/go-proxy/internal/net/types"
)
// Setup implements types.Stream.
func (w *Watcher) Addr() net.Addr {
return w.stream.Addr()
var _ nettypes.Stream = (*Watcher)(nil)
// ListenAndServe implements nettypes.Stream.
func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) {
w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck
return w.preDial(ctx, predial)
}, func(ctx context.Context) error {
return w.onRead(ctx, onRead)
})
}
// Setup implements types.Stream.
func (w *Watcher) Setup() error {
return w.stream.Setup()
}
// Accept implements types.Stream.
func (w *Watcher) Accept() (conn gpnet.StreamConn, err error) {
conn, err = w.stream.Accept()
if err != nil {
return
}
if wakeErr := w.wakeFromStream(); wakeErr != nil {
w.l.Err(wakeErr).Msg("error waking container")
}
return
}
// Handle implements types.Stream.
func (w *Watcher) Handle(conn gpnet.StreamConn) error {
if err := w.wakeFromStream(); err != nil {
return err
}
return w.stream.Handle(conn)
}
// Close implements types.Stream.
// Close implements nettypes.Stream.
func (w *Watcher) Close() error {
return w.stream.Close()
}
func (w *Watcher) wakeFromStream() error {
// LocalAddr implements nettypes.Stream.
func (w *Watcher) LocalAddr() net.Addr {
return w.stream.LocalAddr()
}
func (w *Watcher) preDial(ctx context.Context, predial nettypes.HookFunc) error {
if predial != nil {
if err := predial(ctx); err != nil {
return err
}
}
return w.wakeFromStream(ctx)
}
func (w *Watcher) onRead(ctx context.Context, onRead nettypes.HookFunc) error {
w.resetIdleTimer()
if onRead != nil {
if err := onRead(ctx); err != nil {
return err
}
}
return nil
}
func (w *Watcher) wakeFromStream(ctx context.Context) error {
w.resetIdleTimer()
// pass through if container is already ready
@ -52,18 +58,27 @@ func (w *Watcher) wakeFromStream() error {
}
w.l.Debug().Msg("wake signal received")
err := w.Wake(context.Background())
err := w.Wake(ctx)
if err != nil {
return err
}
for {
w.resetIdleTimer()
if w.canceled(ctx) {
return nil
}
if !w.waitStarted(ctx) {
return nil
}
ready, err := w.checkUpdateState()
if err != nil {
return err
}
if ready {
w.resetIdleTimer()
w.l.Debug().Stringer("url", w.hc.URL()).Msg("container is ready, passing through")
return nil
}

View file

@ -261,7 +261,7 @@ func NewWatcher(parent task.Parent, r routes.Route, cfg *idlewatcher.Config) (*W
case routes.ReverseProxyRoute:
w.rp = r.ReverseProxy()
case routes.StreamRoute:
w.stream = r
w.stream = r.Stream()
default:
w.provider.Close()
return nil, w.newWatcherError(gperr.Errorf("unexpected route type: %T", r))

View file

@ -6,9 +6,9 @@ import (
)
type Stream interface {
ListenAndServe(ctx context.Context, preDial PreDialFunc)
ListenAndServe(ctx context.Context, preDial, onRead HookFunc)
LocalAddr() net.Addr
Close() error
}
type PreDialFunc func(ctx context.Context) error
type HookFunc func(ctx context.Context) error

View file

@ -16,7 +16,7 @@ import (
"github.com/yusing/go-proxy/internal/homepage"
idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types"
netutils "github.com/yusing/go-proxy/internal/net"
net "github.com/yusing/go-proxy/internal/net/types"
nettypes "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/proxmox"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils"
@ -39,7 +39,7 @@ type (
Alias string `json:"alias"`
Scheme route.Scheme `json:"scheme,omitempty"`
Host string `json:"host,omitempty"`
Port route.Port `json:"port,omitempty"`
Port route.Port `json:"port"`
Root string `json:"root,omitempty"`
route.HTTPConfig
@ -64,8 +64,8 @@ type (
Provider string `json:"provider,omitempty"` // for backward compatibility
// private fields
LisURL *net.URL `json:"lurl,omitempty"`
ProxyURL *net.URL `json:"purl,omitempty"`
LisURL *nettypes.URL `json:"lurl,omitempty"`
ProxyURL *nettypes.URL `json:"purl,omitempty"`
Excluded *bool `json:"excluded"`
@ -195,19 +195,19 @@ func (r *Route) Validate() gperr.Error {
switch r.Scheme {
case route.SchemeFileServer:
r.ProxyURL = gperr.Collect(errs, net.ParseURL, "file://"+r.Root)
r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, "file://"+r.Root)
r.Host = ""
r.Port.Proxy = 0
case route.SchemeHTTP, route.SchemeHTTPS:
if r.Port.Listening != 0 {
errs.Addf("unexpected listening port for %s scheme", r.Scheme)
}
r.ProxyURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
case route.SchemeTCP, route.SchemeUDP:
if !r.ShouldExclude() {
r.LisURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://:%d", r.Scheme, r.Port.Listening))
r.LisURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://:%d", r.Scheme, r.Port.Listening))
}
r.ProxyURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
}
if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) {
@ -309,7 +309,7 @@ func (r *Route) ProviderName() string {
return r.Provider
}
func (r *Route) TargetURL() *net.URL {
func (r *Route) TargetURL() *nettypes.URL {
return r.ProxyURL
}

View file

@ -58,6 +58,7 @@ type (
StreamRoute interface {
Route
nettypes.Stream
Stream() nettypes.Stream
}
Provider interface {
GetRoute(alias string) (r Route, ok bool)

View file

@ -2,7 +2,8 @@ package route
import (
"context"
"errors"
"fmt"
"net"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -10,6 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/idlewatcher"
nettypes "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/route/stream"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher/health/monitor"
)
@ -17,7 +19,7 @@ import (
// TODO: support stream load balance.
type StreamRoute struct {
*Route
nettypes.Stream `json:"-"`
stream nettypes.Stream
l zerolog.Logger
}
@ -33,10 +35,19 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) {
}, nil
}
func (r *StreamRoute) Stream() nettypes.Stream {
return r.stream
}
// Start implements task.TaskStarter.
func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
stream, err := r.initStream()
if err != nil {
return gperr.Wrap(err)
}
r.stream = stream
r.task = parent.Subtask("stream."+r.Name(), !r.ShouldExclude())
r.Stream = NewStream(r)
switch {
case r.UseIdleWatcher():
@ -45,20 +56,12 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
r.task.Finish(err)
return gperr.Wrap(err, "idlewatcher error")
}
r.Stream = waker
r.stream = waker
r.HealthMon = waker
case r.UseHealthCheck():
r.HealthMon = monitor.NewMonitor(r)
}
if !r.ShouldExclude() {
if err := r.Setup(); err != nil {
r.task.Finish(err)
return gperr.Wrap(err)
}
r.l.Info().Int("port", r.Port.Listening).Msg("listening")
}
if r.HealthMon != nil {
if err := r.HealthMon.Start(r.task); err != nil {
gperr.LogWarn("health monitor error", err, &r.l)
@ -73,7 +76,14 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
return err
}
go r.acceptConnections()
r.ListenAndServe(r.task.Context(), nil, nil)
r.l = r.l.With().Stringer("rurl", r.ProxyURL).Stringer("laddr", r.LocalAddr()).Logger()
r.l.Info().Msg("stream started")
r.task.OnCancel("close_stream", func() {
r.stream.Close()
r.l.Info().Msg("stream closed")
})
routes.Stream.Add(r)
r.task.OnCancel("remove_route_from_stream", func() {
@ -82,38 +92,34 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
return nil
}
func (r *StreamRoute) acceptConnections() {
defer r.task.Finish("listener closed")
go func() {
<-r.task.Context().Done()
r.Close()
}()
for {
select {
case <-r.task.Context().Done():
return
default:
conn, err := r.Accept()
if err != nil {
select {
case <-r.task.Context().Done():
default:
gperr.LogError("accept connection error", err, &r.l)
}
r.task.Finish(err)
return
}
if conn == nil {
panic("connection is nil")
}
go func() {
err := r.Handle(conn)
if err != nil && !errors.Is(err, context.Canceled) {
gperr.LogError("handle connection error", err, &r.l)
}
}()
}
}
func (r *StreamRoute) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
r.stream.ListenAndServe(ctx, preDial, onRead)
}
func (r *StreamRoute) Close() error {
return r.stream.Close()
}
func (r *StreamRoute) LocalAddr() net.Addr {
return r.stream.LocalAddr()
}
func (r *StreamRoute) initStream() (nettypes.Stream, error) {
lurl, rurl := r.LisURL, r.ProxyURL
if lurl != nil && lurl.Scheme != rurl.Scheme {
return nil, fmt.Errorf("incoherent scheme is not yet supported: %s != %s", lurl.Scheme, rurl.Scheme)
}
laddr := ":0"
if lurl != nil {
laddr = lurl.Host
}
switch rurl.Scheme {
case "tcp":
return stream.NewTCPTCPStream(laddr, rurl.Host)
case "udp":
return stream.NewUDPUDPStream(laddr, rurl.Host)
}
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
}

View file

@ -0,0 +1,12 @@
//go:build debug
package stream
import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func logDebugf(stream zerolog.LogObjectMarshaler, format string, v ...any) {
log.Debug().Object("stream", stream).Msgf(format, v...)
}

View file

@ -0,0 +1,7 @@
//go:build !debug
package stream
import "github.com/rs/zerolog"
func logDebugf(stream zerolog.LogObjectMarshaler, format string, v ...any) {}

View file

@ -0,0 +1,41 @@
package stream
import (
"context"
"errors"
"io"
"syscall"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
func convertErr(err error) error {
if err == nil {
return nil
}
switch {
case errors.Is(err, context.Canceled),
errors.Is(err, io.ErrClosedPipe),
errors.Is(err, syscall.ECONNRESET):
return nil
default:
return err
}
}
func logErr(stream zerolog.LogObjectMarshaler, err error, msg string) {
err = convertErr(err)
if err == nil {
return
}
log.Err(err).Object("stream", stream).Msg(msg)
}
func logErrf(stream zerolog.LogObjectMarshaler, err error, format string, v ...any) {
err = convertErr(err)
if err == nil {
return
}
log.Err(err).Object("stream", stream).Msgf(format, v...)
}

View file

@ -0,0 +1,162 @@
package stream
import (
"context"
"net"
"github.com/rs/zerolog"
nettypes "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
"go.uber.org/atomic"
)
type TCPTCPStream struct {
listener *net.TCPListener
laddr *net.TCPAddr
dst *net.TCPAddr
preDial nettypes.HookFunc
onRead nettypes.HookFunc
closed atomic.Bool
}
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
dst, err := net.ResolveTCPAddr("tcp", dstAddr)
if err != nil {
return nil, err
}
laddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, err
}
return &TCPTCPStream{laddr: laddr, dst: dst}, nil
}
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
listener, err := net.ListenTCP("tcp", s.laddr)
if err != nil {
logErr(s, err, "failed to listen")
return
}
s.listener = listener
s.preDial = preDial
s.onRead = onRead
go s.listen(ctx)
}
func (s *TCPTCPStream) Close() error {
if s.closed.Swap(true) || s.listener == nil {
return nil
}
return s.listener.Close()
}
func (s *TCPTCPStream) LocalAddr() net.Addr {
if s.listener == nil {
return s.laddr
}
return s.listener.Addr()
}
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
e.Str("protocol", "tcp-tcp").Str("listen", s.listener.Addr().String()).Str("dst", s.dst.String())
}
func (s *TCPTCPStream) listen(ctx context.Context) {
for {
if s.closed.Load() {
return
}
select {
case <-ctx.Done():
return
default:
conn, err := s.listener.Accept()
if err != nil {
if s.closed.Load() {
return
}
logErr(s, err, "failed to accept connection")
continue
}
if s.onRead != nil {
if err := s.onRead(ctx); err != nil {
logErr(s, err, "failed to on read")
continue
}
}
go s.handle(ctx, conn)
}
}
}
func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
if s.preDial != nil {
if err := s.preDial(ctx); err != nil {
if !s.closed.Load() {
logErr(s, err, "failed to pre-dial")
}
return
}
}
if s.closed.Load() {
return
}
dstConn, err := net.DialTCP("tcp", nil, s.dst)
if err != nil {
if !s.closed.Load() {
logErr(s, err, "failed to dial destination")
}
return
}
defer dstConn.Close()
if s.closed.Load() {
return
}
src := conn
dst := net.Conn(dstConn)
if s.onRead != nil {
src = &wrapperConn{
Conn: conn,
ctx: ctx,
onRead: s.onRead,
}
dst = &wrapperConn{
Conn: dstConn,
ctx: ctx,
onRead: s.onRead,
}
}
pipe := utils.NewBidirectionalPipe(ctx, src, dst)
if err := pipe.Start(); err != nil && !s.closed.Load() {
logErr(s, err, "error in bidirectional pipe")
}
}
type wrapperConn struct {
net.Conn
ctx context.Context
onRead nettypes.HookFunc
}
func (w *wrapperConn) Read(b []byte) (n int, err error) {
n, err = w.Conn.Read(b)
if err != nil {
return
}
if w.onRead != nil {
if err = w.onRead(w.ctx); err != nil {
return
}
}
return
}

View file

@ -0,0 +1,316 @@
package stream
import (
"bytes"
"context"
"maps"
"net"
"sync"
"time"
"github.com/rs/zerolog"
nettypes "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/synk"
"go.uber.org/atomic"
)
type UDPUDPStream struct {
name string
listener *net.UDPConn
laddr *net.UDPAddr
dst *net.UDPAddr
preDial nettypes.HookFunc
onRead nettypes.HookFunc
cleanUpTicker *time.Ticker
conns map[string]*udpUDPConn
closed atomic.Bool
mu sync.Mutex
}
type udpUDPConn struct {
srcAddr *net.UDPAddr
dstConn *net.UDPConn
listener *net.UDPConn
lastUsed atomic.Time
closed atomic.Bool
mu sync.Mutex
}
const (
udpBufferSize = 16 * 1024
udpIdleTimeout = 5 * time.Minute // Longer timeout for game sessions
udpCleanupInterval = 1 * time.Minute
udpReadTimeout = 30 * time.Second
)
var bufPool = synk.NewBytesPool()
func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
dst, err := net.ResolveUDPAddr("udp", dstAddr)
if err != nil {
return nil, err
}
laddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, err
}
return &UDPUDPStream{
laddr: laddr,
dst: dst,
conns: make(map[string]*udpUDPConn),
}, nil
}
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
listener, err := net.ListenUDP("udp", s.laddr)
if err != nil {
logErr(s, err, "failed to listen")
return
}
s.listener = listener
s.preDial = preDial
s.onRead = onRead
go s.listen(ctx)
go s.cleanUp(ctx)
}
func (s *UDPUDPStream) Close() error {
if s.closed.Swap(true) || s.listener == nil {
return nil
}
var wg sync.WaitGroup
s.mu.Lock()
for _, conn := range s.conns {
wg.Add(1)
go func(c *udpUDPConn) {
defer wg.Done()
c.Close()
}(conn)
}
clear(s.conns)
s.mu.Unlock()
wg.Wait()
return s.listener.Close()
}
func (s *UDPUDPStream) LocalAddr() net.Addr {
if s.listener == nil {
return s.laddr
}
return s.listener.LocalAddr()
}
func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) {
e.Str("protocol", "udp-udp").Str("name", s.name).Str("dst", s.dst.String())
}
func (s *UDPUDPStream) listen(ctx context.Context) {
buf := bufPool.GetSized(udpBufferSize)
defer bufPool.Put(buf)
for {
select {
case <-ctx.Done():
return
default:
n, srcAddr, err := s.listener.ReadFromUDP(buf)
if err != nil {
if s.closed.Load() {
return
}
logErr(s, err, "failed to read from listener")
continue
}
logDebugf(s, "read %d bytes from %s", n, srcAddr)
if s.onRead != nil {
if err := s.onRead(ctx); err != nil {
logErr(s, err, "failed to on read")
continue
}
}
// Get or create connection, passing the initial data
go s.getOrCreateConnection(ctx, srcAddr, bytes.Clone(buf[:n]))
}
}
}
func (s *UDPUDPStream) getOrCreateConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) {
key := srcAddr.String()
s.mu.Lock()
if conn, ok := s.conns[key]; ok {
s.mu.Unlock()
// Forward packet for existing connection
go conn.forwardToDestination(initialData)
return
}
defer s.mu.Unlock()
// Create new connection with initial data
conn, ok := s.createConnection(ctx, srcAddr, initialData)
if ok && !conn.closed.Load() {
s.conns[key] = conn
}
}
func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) (*udpUDPConn, bool) {
// Apply pre-dial if configured
if s.preDial != nil {
if err := s.preDial(ctx); err != nil {
logErr(s, err, "failed to pre-dial")
return nil, false
}
}
// Create UDP connection to destination
dstConn, err := net.DialUDP("udp", nil, s.dst)
if err != nil {
logErr(s, err, "failed to dial dst")
return nil, false
}
conn := &udpUDPConn{
srcAddr: srcAddr,
dstConn: dstConn,
listener: s.listener,
}
conn.lastUsed.Store(time.Now())
// Send initial data before starting response handler
if !conn.forwardToDestination(initialData) {
dstConn.Close()
return nil, false
}
// Start response handler after initial data is sent
go conn.handleResponses(ctx)
logDebugf(s, "created new connection from %s", srcAddr.String())
return conn, true
}
func (conn *udpUDPConn) MarshalZerologObject(e *zerolog.Event) {
e.Stringer("src", conn.srcAddr).Stringer("dst", conn.dstConn.RemoteAddr())
}
func (conn *udpUDPConn) handleResponses(ctx context.Context) {
buf := bufPool.GetSized(udpBufferSize)
defer bufPool.Put(buf)
defer conn.Close()
for {
if conn.closed.Load() {
return
}
select {
case <-ctx.Done():
return
default:
// Set a reasonable timeout for reads
_ = conn.dstConn.SetReadDeadline(time.Now().Add(udpReadTimeout))
n, err := conn.dstConn.Read(buf)
if err != nil {
if !conn.closed.Load() {
logErr(conn, err, "failed to read from dst")
}
return
}
// Clear deadline after successful read
_ = conn.dstConn.SetReadDeadline(time.Time{})
// Forward response back to client using the listener
_, err = conn.listener.WriteToUDP(buf[:n], conn.srcAddr)
if err != nil {
if !conn.closed.Load() {
logErrf(conn, err, "failed to write %d bytes to client", n)
}
return
}
conn.lastUsed.Store(time.Now())
logDebugf(conn, "forwarded response to client, %d bytes", n)
}
}
}
func (s *UDPUDPStream) cleanUp(ctx context.Context) {
s.cleanUpTicker = time.NewTicker(udpCleanupInterval)
defer s.cleanUpTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-s.cleanUpTicker.C:
s.mu.Lock()
conns := maps.Clone(s.conns)
s.mu.Unlock()
removed := []string(nil)
for key, conn := range conns {
if conn.Expired() {
conn.Close()
removed = append(removed, key)
}
}
s.mu.Lock()
for _, key := range removed {
logDebugf(s, "cleaning up expired connection: %s", key)
delete(s.conns, key)
}
s.mu.Unlock()
}
}
}
func (conn *udpUDPConn) forwardToDestination(data []byte) bool {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.closed.Load() {
return false
}
_, err := conn.dstConn.Write(data)
if err != nil {
logErrf(conn, err, "failed to write %d bytes to dst", len(data))
return false
}
conn.lastUsed.Store(time.Now())
logDebugf(conn, "forwarded %d bytes to dst", len(data))
return true
}
func (conn *udpUDPConn) Expired() bool {
return time.Since(conn.lastUsed.Load()) > udpIdleTimeout
}
func (conn *udpUDPConn) Close() {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.closed.Load() {
return
}
conn.closed.Store(true)
conn.dstConn.Close()
conn.dstConn = nil
}

View file

@ -1,129 +0,0 @@
package route
import (
"errors"
"fmt"
"io"
"net"
"time"
"github.com/yusing/go-proxy/internal/net/types"
U "github.com/yusing/go-proxy/internal/utils"
)
type (
Stream struct {
*StreamRoute
listener types.StreamListener
targetAddr net.Addr
}
)
const (
streamFirstConnBufferSize = 128
streamDialTimeout = 5 * time.Second
)
func NewStream(base *StreamRoute) *Stream {
return &Stream{
StreamRoute: base,
}
}
func (stream *Stream) Addr() net.Addr {
if stream.listener == nil {
panic("listener is nil")
}
return stream.listener.Addr()
}
func (stream *Stream) Setup() error {
var lcfg net.ListenConfig
var err error
ctx := stream.task.Context()
switch stream.Scheme {
case "tcp":
stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.ProxyURL.Host)
if err != nil {
return err
}
tcpListener, err := lcfg.Listen(ctx, "tcp", stream.LisURL.Host)
if err != nil {
return err
}
// in case ListeningPort was zero, get the actual port
stream.Port.Listening = tcpListener.Addr().(*net.TCPAddr).Port
stream.listener = types.NetListener(tcpListener)
case "udp":
stream.targetAddr, err = net.ResolveUDPAddr("udp", stream.ProxyURL.Host)
if err != nil {
return err
}
udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.LisURL.Host)
if err != nil {
return err
}
udpConn, ok := udpListener.(*net.UDPConn)
if !ok {
udpListener.Close()
return errors.New("udp listener is not *net.UDPConn")
}
stream.Port.Listening = udpConn.LocalAddr().(*net.UDPAddr).Port
stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr)
default:
panic("should not reach here")
}
return nil
}
func (stream *Stream) Accept() (conn types.StreamConn, err error) {
if stream.listener == nil {
return nil, errors.New("listener is nil")
}
// prevent Accept from blocking forever
done := make(chan struct{})
go func() {
conn, err = stream.listener.Accept()
close(done)
}()
select {
case <-stream.task.Context().Done():
stream.Close()
return nil, stream.task.Context().Err()
case <-done:
return conn, nil
}
}
func (stream *Stream) Handle(conn types.StreamConn) error {
switch conn := conn.(type) {
case *UDPConn:
switch stream := stream.listener.(type) {
case *UDPForwarder:
return stream.Handle(conn)
default:
return fmt.Errorf("unexpected listener type: %T", stream)
}
case io.ReadWriteCloser:
dialer := &net.Dialer{Timeout: streamDialTimeout}
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
if err != nil {
return err
}
defer dstConn.Close()
defer conn.Close()
pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn)
return pipe.Start()
default:
return fmt.Errorf("unexpected conn type: %T", conn)
}
}
func (stream *Stream) Close() error {
return stream.listener.Close()
}

View file

@ -1,204 +0,0 @@
package route
import (
"context"
"fmt"
"net"
"sync"
"github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/net/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
UDPForwarder struct {
ctx context.Context
forwarder *net.UDPConn
dstAddr net.Addr
connMap F.Map[string, *UDPConn]
mu sync.Mutex
}
UDPConn struct {
srcAddr *net.UDPAddr
conn net.Conn
buf *UDPBuf
}
UDPBuf struct {
data, oob []byte
n, oobn int
}
)
const udpConnBufferSize = 4096
func NewUDPForwarder(ctx context.Context, forwarder *net.UDPConn, dstAddr net.Addr) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
forwarder: forwarder,
dstAddr: dstAddr,
connMap: F.NewMapOf[string, *UDPConn](),
}
}
func newUDPBuf() *UDPBuf {
return &UDPBuf{
data: make([]byte, udpConnBufferSize),
oob: make([]byte, udpConnBufferSize),
}
}
func (conn *UDPConn) DstAddrString() string {
return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String()
}
func (w *UDPForwarder) Addr() net.Addr {
return w.forwarder.LocalAddr()
}
func (w *UDPForwarder) Accept() (types.StreamConn, error) {
buf := newUDPBuf()
addr, err := w.readFromListener(buf)
if err != nil {
return nil, err
}
return &UDPConn{
srcAddr: addr,
buf: buf,
}, nil
}
func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) {
switch dstAddr := w.dstAddr.(type) {
case *net.UDPAddr:
var laddr *net.UDPAddr
if dstAddr.IP.IsLoopback() {
laddr, _ = net.ResolveUDPAddr(dstAddr.Network(), "127.0.0.1:")
}
dstConn, err = net.DialUDP(w.dstAddr.Network(), laddr, dstAddr)
case *net.TCPAddr:
dstConn, err = net.DialTCP(w.dstAddr.Network(), nil, dstAddr)
default:
err = fmt.Errorf("unsupported network %s", w.dstAddr.Network())
}
return
}
func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) {
buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob)
if err == nil {
log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn)
}
return
}
func (conn *UDPConn) read() (err error) {
switch dstConn := conn.conn.(type) {
case *net.UDPConn:
conn.buf.n, conn.buf.oobn, _, _, err = dstConn.ReadMsgUDP(conn.buf.data, conn.buf.oob)
default:
conn.buf.n, err = dstConn.Read(conn.buf.data[:conn.buf.n])
conn.buf.oobn = 0
}
if err == nil {
log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
}
return
}
func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) {
buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr)
if err == nil {
log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn)
}
return
}
func (conn *UDPConn) write() (err error) {
switch dstConn := conn.conn.(type) {
case *net.UDPConn:
conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil)
if err == nil {
log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn)
}
default:
_, err = dstConn.Write(conn.buf.data[:conn.buf.n])
if err == nil {
log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n)
}
}
return
}
func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) {
w.mu.Lock()
defer w.mu.Unlock()
dst, ok := w.connMap.Load(key)
if !ok {
var err error
dst = conn
dst.conn, err = w.dialDst()
if err != nil {
return nil, err
}
if err := dst.write(); err != nil {
dst.conn.Close()
return nil, err
}
w.connMap.Store(key, dst)
} else {
conn.conn = dst.conn
if err := conn.write(); err != nil {
w.connMap.Delete(key)
dst.conn.Close()
return nil, err
}
}
return dst, nil
}
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
conn, ok := streamConn.(*UDPConn)
if !ok {
panic("unexpected conn type")
}
key := conn.srcAddr.String()
dst, err := w.getInitConn(conn, key)
if err != nil {
return err
}
for {
select {
case <-w.ctx.Done():
return nil
default:
if err := dst.read(); err != nil {
w.connMap.Delete(key)
dst.conn.Close()
return err
}
if err := w.writeToSrc(dst.srcAddr, dst.buf); err != nil {
return err
}
}
}
}
func (w *UDPForwarder) Close() error {
errs := gperr.NewBuilder("errors closing udp conn")
w.mu.Lock()
defer w.mu.Unlock()
w.connMap.RangeAll(func(key string, conn *UDPConn) {
errs.Add(conn.conn.Close())
})
w.connMap.Clear()
errs.Add(w.forwarder.Close())
return errs.Error()
}