fix stream task stuck on reload and udp mutex not unlocked properly

This commit is contained in:
yusing 2025-01-05 03:26:31 +08:00
parent e04080bf1c
commit 5e2ce9e1e6
3 changed files with 43 additions and 28 deletions

View file

@ -87,12 +87,6 @@ func (r *StreamRoute) Start(parent task.Parent) E.Error {
return E.From(err) return E.From(err)
} }
r.task.OnCancel("close_stream", func() {
if err := r.Stream.Close(); err != nil {
E.LogError("close stream failed", err, &r.l)
}
})
r.l.Info(). r.l.Info().
Int("port", int(r.Port.ListeningPort)). Int("port", int(r.Port.ListeningPort)).
Msg("listening") Msg("listening")

View file

@ -43,16 +43,19 @@ func (stream *Stream) Setup() error {
var lcfg net.ListenConfig var lcfg net.ListenConfig
var err error var err error
ctx := stream.task.Context()
switch stream.Scheme.ListeningScheme { switch stream.Scheme.ListeningScheme {
case "tcp": case "tcp":
stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.URL.Host) stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.URL.Host)
if err != nil { if err != nil {
return err return err
} }
tcpListener, err := lcfg.Listen(stream.task.Context(), "tcp", stream.ListenURL.Host) tcpListener, err := lcfg.Listen(ctx, "tcp", stream.ListenURL.Host)
if err != nil { if err != nil {
return err return err
} }
// in case ListeningPort was zero, get the actual port
stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port) stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
stream.listener = types.NetListener(tcpListener) stream.listener = types.NetListener(tcpListener)
case "udp": case "udp":
@ -60,7 +63,7 @@ func (stream *Stream) Setup() error {
if err != nil { if err != nil {
return err return err
} }
udpListener, err := lcfg.ListenPacket(stream.task.Context(), "udp", stream.ListenURL.Host) udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.ListenURL.Host)
if err != nil { if err != nil {
return err return err
} }
@ -70,7 +73,7 @@ func (stream *Stream) Setup() error {
return errors.New("udp listener is not *net.UDPConn") return errors.New("udp listener is not *net.UDPConn")
} }
stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port) stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port)
stream.listener = NewUDPForwarder(stream.task.Context(), udpConn, stream.targetAddr) stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr)
default: default:
panic("should not reach here") panic("should not reach here")
} }
@ -78,11 +81,24 @@ func (stream *Stream) Setup() error {
return nil return nil
} }
func (stream *Stream) Accept() (types.StreamConn, error) { func (stream *Stream) Accept() (conn types.StreamConn, err error) {
if stream.listener == nil { if stream.listener == nil {
return nil, errors.New("listener is nil") return nil, errors.New("listener is nil")
} }
return stream.listener.Accept() // prevent Accept from blocking forever
done := make(chan struct{})
go func() {
conn, err = stream.listener.Accept()
close(done)
}()
select {
case <-stream.task.Context().Done():
stream.Close()
return nil, stream.task.Context().Err()
case <-done:
return conn, nil
}
} }
func (stream *Stream) Handle(conn types.StreamConn) error { func (stream *Stream) Handle(conn types.StreamConn) error {
@ -95,14 +111,13 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
return fmt.Errorf("unexpected listener type: %T", stream) return fmt.Errorf("unexpected listener type: %T", stream)
} }
case io.ReadWriteCloser: case io.ReadWriteCloser:
stream.task.OnCancel("close_conn", func() { conn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout} dialer := &net.Dialer{Timeout: streamDialTimeout}
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String()) dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
if err != nil { if err != nil {
return err return err
} }
defer dstConn.Close() defer dstConn.Close()
defer conn.Close()
pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn) pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn)
return pipe.Start() return pipe.Start()
default: default:

View file

@ -48,10 +48,6 @@ func newUDPBuf() *UDPBuf {
} }
} }
func (conn *UDPConn) SrcAddrString() string {
return conn.srcAddr.Network() + "://" + conn.srcAddr.String()
}
func (conn *UDPConn) DstAddrString() string { func (conn *UDPConn) DstAddrString() string {
return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String() return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String()
} }
@ -135,25 +131,21 @@ func (conn *UDPConn) write() (err error) {
return nil return nil
} }
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) {
conn, ok := streamConn.(*UDPConn)
if !ok {
panic("unexpected conn type")
}
key := conn.srcAddr.String()
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock()
dst, ok := w.connMap.Load(key) dst, ok := w.connMap.Load(key)
if !ok { if !ok {
var err error var err error
dst = conn dst = conn
dst.conn, err = w.dialDst() dst.conn, err = w.dialDst()
if err != nil { if err != nil {
return err return nil, err
} }
if err := dst.write(); err != nil { if err := dst.write(); err != nil {
dst.conn.Close() dst.conn.Close()
return err return nil, err
} }
w.connMap.Store(key, dst) w.connMap.Store(key, dst)
} else { } else {
@ -161,10 +153,24 @@ func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
if err := conn.write(); err != nil { if err := conn.write(); err != nil {
w.connMap.Delete(key) w.connMap.Delete(key)
dst.conn.Close() dst.conn.Close()
return err return nil, err
} }
} }
w.mu.Unlock()
return dst, nil
}
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
conn, ok := streamConn.(*UDPConn)
if !ok {
panic("unexpected conn type")
}
key := conn.srcAddr.String()
dst, err := w.getInitConn(conn, key)
if err != nil {
return err
}
for { for {
select { select {