From 5e2ce9e1e617e4d3c21b62f1a1e350e1bd2f6ecc Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 5 Jan 2025 03:26:31 +0800 Subject: [PATCH] fix stream task stuck on reload and udp mutex not unlocked properly --- internal/route/stream.go | 6 ------ internal/route/stream_impl.go | 29 +++++++++++++++++++------- internal/route/udp_forwarder.go | 36 +++++++++++++++++++-------------- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/internal/route/stream.go b/internal/route/stream.go index 62335f0..1713fb5 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -87,12 +87,6 @@ func (r *StreamRoute) Start(parent task.Parent) E.Error { return E.From(err) } - r.task.OnCancel("close_stream", func() { - if err := r.Stream.Close(); err != nil { - E.LogError("close stream failed", err, &r.l) - } - }) - r.l.Info(). Int("port", int(r.Port.ListeningPort)). Msg("listening") diff --git a/internal/route/stream_impl.go b/internal/route/stream_impl.go index 5908bea..62321b1 100644 --- a/internal/route/stream_impl.go +++ b/internal/route/stream_impl.go @@ -43,16 +43,19 @@ func (stream *Stream) Setup() error { var lcfg net.ListenConfig var err error + ctx := stream.task.Context() + switch stream.Scheme.ListeningScheme { case "tcp": stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.URL.Host) if err != nil { return err } - tcpListener, err := lcfg.Listen(stream.task.Context(), "tcp", stream.ListenURL.Host) + tcpListener, err := lcfg.Listen(ctx, "tcp", stream.ListenURL.Host) if err != nil { return err } + // in case ListeningPort was zero, get the actual port stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port) stream.listener = types.NetListener(tcpListener) case "udp": @@ -60,7 +63,7 @@ func (stream *Stream) Setup() error { if err != nil { return err } - udpListener, err := lcfg.ListenPacket(stream.task.Context(), "udp", stream.ListenURL.Host) + udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.ListenURL.Host) if err != nil { return err } @@ -70,7 +73,7 @@ func (stream *Stream) Setup() error { return errors.New("udp listener is not *net.UDPConn") } stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port) - stream.listener = NewUDPForwarder(stream.task.Context(), udpConn, stream.targetAddr) + stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr) default: panic("should not reach here") } @@ -78,11 +81,24 @@ func (stream *Stream) Setup() error { return nil } -func (stream *Stream) Accept() (types.StreamConn, error) { +func (stream *Stream) Accept() (conn types.StreamConn, err error) { if stream.listener == nil { return nil, errors.New("listener is nil") } - return stream.listener.Accept() + // prevent Accept from blocking forever + done := make(chan struct{}) + go func() { + conn, err = stream.listener.Accept() + close(done) + }() + + select { + case <-stream.task.Context().Done(): + stream.Close() + return nil, stream.task.Context().Err() + case <-done: + return conn, nil + } } func (stream *Stream) Handle(conn types.StreamConn) error { @@ -95,14 +111,13 @@ func (stream *Stream) Handle(conn types.StreamConn) error { return fmt.Errorf("unexpected listener type: %T", stream) } case io.ReadWriteCloser: - stream.task.OnCancel("close_conn", func() { conn.Close() }) - dialer := &net.Dialer{Timeout: streamDialTimeout} dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String()) if err != nil { return err } defer dstConn.Close() + defer conn.Close() pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn) return pipe.Start() default: diff --git a/internal/route/udp_forwarder.go b/internal/route/udp_forwarder.go index afb4c3e..84572e2 100644 --- a/internal/route/udp_forwarder.go +++ b/internal/route/udp_forwarder.go @@ -48,10 +48,6 @@ func newUDPBuf() *UDPBuf { } } -func (conn *UDPConn) SrcAddrString() string { - return conn.srcAddr.Network() + "://" + conn.srcAddr.String() -} - func (conn *UDPConn) DstAddrString() string { return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String() } @@ -135,25 +131,21 @@ func (conn *UDPConn) write() (err error) { return nil } -func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { - conn, ok := streamConn.(*UDPConn) - if !ok { - panic("unexpected conn type") - } - key := conn.srcAddr.String() - +func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) { w.mu.Lock() + defer w.mu.Unlock() + dst, ok := w.connMap.Load(key) if !ok { var err error dst = conn dst.conn, err = w.dialDst() if err != nil { - return err + return nil, err } if err := dst.write(); err != nil { dst.conn.Close() - return err + return nil, err } w.connMap.Store(key, dst) } else { @@ -161,10 +153,24 @@ func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { if err := conn.write(); err != nil { w.connMap.Delete(key) dst.conn.Close() - return err + return nil, err } } - w.mu.Unlock() + + return dst, nil +} + +func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { + conn, ok := streamConn.(*UDPConn) + if !ok { + panic("unexpected conn type") + } + + key := conn.srcAddr.String() + dst, err := w.getInitConn(conn, key) + if err != nil { + return err + } for { select {