From 9a81b13b67f359991659bfbc4f6f4d4b32b9a554 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 21 Sep 2024 13:40:20 +0800 Subject: [PATCH] fixing tcp/udp error on closing --- src/route/stream_route.go | 39 ++++++++++++++++++++++----------------- src/route/tcp_route.go | 21 +++++++++++---------- src/route/udp_route.go | 8 +------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/route/stream_route.go b/src/route/stream_route.go index 56b1a83..8c492ab 100755 --- a/src/route/stream_route.go +++ b/src/route/stream_route.go @@ -1,6 +1,7 @@ package route import ( + "context" "fmt" "sync" "sync/atomic" @@ -15,8 +16,10 @@ type StreamRoute struct { P.StreamEntry StreamImpl `json:"-"` - wg sync.WaitGroup - stopCh chan struct{} + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + connCh chan any started atomic.Bool l logrus.FieldLogger @@ -36,8 +39,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { } base := &StreamRoute{ StreamEntry: *entry, - wg: sync.WaitGroup{}, - connCh: make(chan any), + connCh: make(chan any, 100), } if entry.Scheme.ListeningScheme.IsTCP() { base.StreamImpl = NewTCPRoute(base) @@ -54,9 +56,9 @@ func (r *StreamRoute) String() string { func (r *StreamRoute) Start() E.NestedError { if r.started.Load() { - return E.Invalid("state", "already started") + return nil } - r.stopCh = make(chan struct{}, 1) + r.ctx, r.cancel = context.WithCancel(context.Background()) r.wg.Wait() if err := r.Setup(); err != nil { return E.FailWith("setup", err) @@ -70,10 +72,10 @@ func (r *StreamRoute) Start() E.NestedError { func (r *StreamRoute) Stop() E.NestedError { if !r.started.Load() { - return E.Invalid("state", "not started") + return nil } l := r.l - close(r.stopCh) + r.cancel() r.CloseListeners() done := make(chan struct{}, 1) @@ -82,13 +84,16 @@ func (r *StreamRoute) Stop() E.NestedError { close(done) }() - select { - case <-done: - l.Info("stopped listening") - case <-time.After(streamStopListenTimeout): - l.Error("timed out waiting for connections") + timeout := time.After(streamStopListenTimeout) + for { + select { + case <-done: + l.Debug("stopped listening") + return nil + case <-timeout: + return E.FailedWhy("stop", "timed out") + } } - return nil } func (r *StreamRoute) grAcceptConnections() { @@ -96,13 +101,13 @@ func (r *StreamRoute) grAcceptConnections() { for { select { - case <-r.stopCh: + case <-r.ctx.Done(): return default: conn, err := r.Accept() if err != nil { select { - case <-r.stopCh: + case <-r.ctx.Done(): return default: r.l.Error(err) @@ -119,7 +124,7 @@ func (r *StreamRoute) grHandleConnections() { for { select { - case <-r.stopCh: + case <-r.ctx.Done(): return case conn := <-r.connCh: go func() { diff --git a/src/route/tcp_route.go b/src/route/tcp_route.go index 2baab41..b2fe35e 100755 --- a/src/route/tcp_route.go +++ b/src/route/tcp_route.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "sync" + "syscall" "time" U "github.com/yusing/go-proxy/utils" @@ -24,7 +25,6 @@ type TCPRoute struct { func NewTCPRoute(base *StreamRoute) StreamImpl { return &TCPRoute{ StreamRoute: base, - listener: nil, pipe: make(Pipes, 0), } } @@ -47,7 +47,7 @@ func (route *TCPRoute) Handle(c any) error { defer clientConn.Close() - ctx, cancel := context.WithTimeout(context.Background(), tcpDialTimeout) + ctx, cancel := context.WithTimeout(route.ctx, tcpDialTimeout) defer cancel() serverAddr := fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort) @@ -58,16 +58,10 @@ func (route *TCPRoute) Handle(c any) error { return err } - pipeCtx, pipeCancel := context.WithCancel(context.Background()) - go func() { - <-route.stopCh - pipeCancel() - }() - route.mu.Lock() defer route.mu.Unlock() - pipe := U.NewBidirectionalPipe(pipeCtx, clientConn, serverConn) + pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn) route.pipe = append(route.pipe, pipe) return pipe.Start() } @@ -80,7 +74,14 @@ func (route *TCPRoute) CloseListeners() { route.listener = nil for _, pipe := range route.pipe { if err := pipe.Stop(); err != nil { - route.l.Error(err) + switch err { + // target closing connection + // TODO: handle this by fixing utils/io.go + case net.ErrClosed, syscall.EPIPE: + return + default: + route.l.Error(err) + } } } } diff --git a/src/route/udp_route.go b/src/route/udp_route.go index 8767712..1b9194e 100755 --- a/src/route/udp_route.go +++ b/src/route/udp_route.go @@ -1,7 +1,6 @@ package route import ( - "context" "fmt" "io" "net" @@ -84,15 +83,10 @@ func (route *UDPRoute) Accept() (any, error) { srcConn.Close() return nil, err } - pipeCtx, pipeCancel := context.WithCancel(context.Background()) - go func() { - <-route.stopCh - pipeCancel() - }() conn = &UDPConn{ srcConn, dstConn, - utils.NewBidirectionalPipe(pipeCtx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), + utils.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), } route.connMap[key] = conn }