fix stream task stuck on reload and udp mutex not unlocked properly

This commit is contained in:
yusing 2025-01-05 03:26:31 +08:00
parent e04080bf1c
commit 5e2ce9e1e6
3 changed files with 43 additions and 28 deletions

View file

@ -87,12 +87,6 @@ func (r *StreamRoute) Start(parent task.Parent) E.Error {
return E.From(err)
}
r.task.OnCancel("close_stream", func() {
if err := r.Stream.Close(); err != nil {
E.LogError("close stream failed", err, &r.l)
}
})
r.l.Info().
Int("port", int(r.Port.ListeningPort)).
Msg("listening")

View file

@ -43,16 +43,19 @@ func (stream *Stream) Setup() error {
var lcfg net.ListenConfig
var err error
ctx := stream.task.Context()
switch stream.Scheme.ListeningScheme {
case "tcp":
stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.URL.Host)
if err != nil {
return err
}
tcpListener, err := lcfg.Listen(stream.task.Context(), "tcp", stream.ListenURL.Host)
tcpListener, err := lcfg.Listen(ctx, "tcp", stream.ListenURL.Host)
if err != nil {
return err
}
// in case ListeningPort was zero, get the actual port
stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
stream.listener = types.NetListener(tcpListener)
case "udp":
@ -60,7 +63,7 @@ func (stream *Stream) Setup() error {
if err != nil {
return err
}
udpListener, err := lcfg.ListenPacket(stream.task.Context(), "udp", stream.ListenURL.Host)
udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.ListenURL.Host)
if err != nil {
return err
}
@ -70,7 +73,7 @@ func (stream *Stream) Setup() error {
return errors.New("udp listener is not *net.UDPConn")
}
stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port)
stream.listener = NewUDPForwarder(stream.task.Context(), udpConn, stream.targetAddr)
stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr)
default:
panic("should not reach here")
}
@ -78,11 +81,24 @@ func (stream *Stream) Setup() error {
return nil
}
func (stream *Stream) Accept() (types.StreamConn, error) {
func (stream *Stream) Accept() (conn types.StreamConn, err error) {
if stream.listener == nil {
return nil, errors.New("listener is nil")
}
return stream.listener.Accept()
// prevent Accept from blocking forever
done := make(chan struct{})
go func() {
conn, err = stream.listener.Accept()
close(done)
}()
select {
case <-stream.task.Context().Done():
stream.Close()
return nil, stream.task.Context().Err()
case <-done:
return conn, nil
}
}
func (stream *Stream) Handle(conn types.StreamConn) error {
@ -95,14 +111,13 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
return fmt.Errorf("unexpected listener type: %T", stream)
}
case io.ReadWriteCloser:
stream.task.OnCancel("close_conn", func() { conn.Close() })
dialer := &net.Dialer{Timeout: streamDialTimeout}
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
if err != nil {
return err
}
defer dstConn.Close()
defer conn.Close()
pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn)
return pipe.Start()
default:

View file

@ -48,10 +48,6 @@ func newUDPBuf() *UDPBuf {
}
}
func (conn *UDPConn) SrcAddrString() string {
return conn.srcAddr.Network() + "://" + conn.srcAddr.String()
}
func (conn *UDPConn) DstAddrString() string {
return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String()
}
@ -135,25 +131,21 @@ func (conn *UDPConn) write() (err error) {
return nil
}
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
conn, ok := streamConn.(*UDPConn)
if !ok {
panic("unexpected conn type")
}
key := conn.srcAddr.String()
func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) {
w.mu.Lock()
defer w.mu.Unlock()
dst, ok := w.connMap.Load(key)
if !ok {
var err error
dst = conn
dst.conn, err = w.dialDst()
if err != nil {
return err
return nil, err
}
if err := dst.write(); err != nil {
dst.conn.Close()
return err
return nil, err
}
w.connMap.Store(key, dst)
} else {
@ -161,10 +153,24 @@ func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
if err := conn.write(); err != nil {
w.connMap.Delete(key)
dst.conn.Close()
return nil, err
}
}
return dst, nil
}
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
conn, ok := streamConn.(*UDPConn)
if !ok {
panic("unexpected conn type")
}
key := conn.srcAddr.String()
dst, err := w.getInitConn(conn, key)
if err != nil {
return err
}
}
w.mu.Unlock()
for {
select {