mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 04:42:33 +02:00
fix stream task stuck on reload and udp mutex not unlocked properly
This commit is contained in:
parent
e04080bf1c
commit
5e2ce9e1e6
3 changed files with 43 additions and 28 deletions
|
@ -87,12 +87,6 @@ func (r *StreamRoute) Start(parent task.Parent) E.Error {
|
|||
return E.From(err)
|
||||
}
|
||||
|
||||
r.task.OnCancel("close_stream", func() {
|
||||
if err := r.Stream.Close(); err != nil {
|
||||
E.LogError("close stream failed", err, &r.l)
|
||||
}
|
||||
})
|
||||
|
||||
r.l.Info().
|
||||
Int("port", int(r.Port.ListeningPort)).
|
||||
Msg("listening")
|
||||
|
|
|
@ -43,16 +43,19 @@ func (stream *Stream) Setup() error {
|
|||
var lcfg net.ListenConfig
|
||||
var err error
|
||||
|
||||
ctx := stream.task.Context()
|
||||
|
||||
switch stream.Scheme.ListeningScheme {
|
||||
case "tcp":
|
||||
stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.URL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tcpListener, err := lcfg.Listen(stream.task.Context(), "tcp", stream.ListenURL.Host)
|
||||
tcpListener, err := lcfg.Listen(ctx, "tcp", stream.ListenURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// in case ListeningPort was zero, get the actual port
|
||||
stream.Port.ListeningPort = T.Port(tcpListener.Addr().(*net.TCPAddr).Port)
|
||||
stream.listener = types.NetListener(tcpListener)
|
||||
case "udp":
|
||||
|
@ -60,7 +63,7 @@ func (stream *Stream) Setup() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpListener, err := lcfg.ListenPacket(stream.task.Context(), "udp", stream.ListenURL.Host)
|
||||
udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.ListenURL.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -70,7 +73,7 @@ func (stream *Stream) Setup() error {
|
|||
return errors.New("udp listener is not *net.UDPConn")
|
||||
}
|
||||
stream.Port.ListeningPort = T.Port(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
||||
stream.listener = NewUDPForwarder(stream.task.Context(), udpConn, stream.targetAddr)
|
||||
stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr)
|
||||
default:
|
||||
panic("should not reach here")
|
||||
}
|
||||
|
@ -78,11 +81,24 @@ func (stream *Stream) Setup() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (stream *Stream) Accept() (types.StreamConn, error) {
|
||||
func (stream *Stream) Accept() (conn types.StreamConn, err error) {
|
||||
if stream.listener == nil {
|
||||
return nil, errors.New("listener is nil")
|
||||
}
|
||||
return stream.listener.Accept()
|
||||
// prevent Accept from blocking forever
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
conn, err = stream.listener.Accept()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stream.task.Context().Done():
|
||||
stream.Close()
|
||||
return nil, stream.task.Context().Err()
|
||||
case <-done:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (stream *Stream) Handle(conn types.StreamConn) error {
|
||||
|
@ -95,14 +111,13 @@ func (stream *Stream) Handle(conn types.StreamConn) error {
|
|||
return fmt.Errorf("unexpected listener type: %T", stream)
|
||||
}
|
||||
case io.ReadWriteCloser:
|
||||
stream.task.OnCancel("close_conn", func() { conn.Close() })
|
||||
|
||||
dialer := &net.Dialer{Timeout: streamDialTimeout}
|
||||
dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dstConn.Close()
|
||||
defer conn.Close()
|
||||
pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn)
|
||||
return pipe.Start()
|
||||
default:
|
||||
|
|
|
@ -48,10 +48,6 @@ func newUDPBuf() *UDPBuf {
|
|||
}
|
||||
}
|
||||
|
||||
func (conn *UDPConn) SrcAddrString() string {
|
||||
return conn.srcAddr.Network() + "://" + conn.srcAddr.String()
|
||||
}
|
||||
|
||||
func (conn *UDPConn) DstAddrString() string {
|
||||
return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String()
|
||||
}
|
||||
|
@ -135,25 +131,21 @@ func (conn *UDPConn) write() (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
|
||||
conn, ok := streamConn.(*UDPConn)
|
||||
if !ok {
|
||||
panic("unexpected conn type")
|
||||
}
|
||||
key := conn.srcAddr.String()
|
||||
|
||||
func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
dst, ok := w.connMap.Load(key)
|
||||
if !ok {
|
||||
var err error
|
||||
dst = conn
|
||||
dst.conn, err = w.dialDst()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err := dst.write(); err != nil {
|
||||
dst.conn.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
w.connMap.Store(key, dst)
|
||||
} else {
|
||||
|
@ -161,10 +153,24 @@ func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
|
|||
if err := conn.write(); err != nil {
|
||||
w.connMap.Delete(key)
|
||||
dst.conn.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func (w *UDPForwarder) Handle(streamConn types.StreamConn) error {
|
||||
conn, ok := streamConn.(*UDPConn)
|
||||
if !ok {
|
||||
panic("unexpected conn type")
|
||||
}
|
||||
|
||||
key := conn.srcAddr.String()
|
||||
dst, err := w.getInitConn(conn, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
|
|
Loading…
Add table
Reference in a new issue