package route import ( "context" "fmt" "net" "sync" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/net/types" F "github.com/yusing/go-proxy/internal/utils/functional" ) type ( UDPForwarder struct { ctx context.Context forwarder *net.UDPConn dstAddr net.Addr connMap F.Map[string, *UDPConn] mu sync.Mutex } UDPConn struct { srcAddr *net.UDPAddr conn net.Conn buf *UDPBuf } UDPBuf struct { data, oob []byte n, oobn int } ) const udpConnBufferSize = 4096 func NewUDPForwarder(ctx context.Context, forwarder *net.UDPConn, dstAddr net.Addr) *UDPForwarder { return &UDPForwarder{ ctx: ctx, forwarder: forwarder, dstAddr: dstAddr, connMap: F.NewMapOf[string, *UDPConn](), } } func newUDPBuf() *UDPBuf { return &UDPBuf{ data: make([]byte, udpConnBufferSize), oob: make([]byte, udpConnBufferSize), } } func (conn *UDPConn) DstAddrString() string { return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String() } func (w *UDPForwarder) Addr() net.Addr { return w.forwarder.LocalAddr() } func (w *UDPForwarder) Accept() (types.StreamConn, error) { buf := newUDPBuf() addr, err := w.readFromListener(buf) if err != nil { return nil, err } return &UDPConn{ srcAddr: addr, buf: buf, }, nil } func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) { switch dstAddr := w.dstAddr.(type) { case *net.UDPAddr: var laddr *net.UDPAddr if dstAddr.IP.IsLoopback() { laddr, _ = net.ResolveUDPAddr(dstAddr.Network(), "127.0.0.1:") } dstConn, err = net.DialUDP(w.dstAddr.Network(), laddr, dstAddr) case *net.TCPAddr: dstConn, err = net.DialTCP(w.dstAddr.Network(), nil, dstAddr) default: err = fmt.Errorf("unsupported network %s", w.dstAddr.Network()) } return } func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) { buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob) if err == nil { logger.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn) } return } func (conn *UDPConn) read() (err error) { switch dstConn := conn.conn.(type) { case *net.UDPConn: conn.buf.n, conn.buf.oobn, _, _, err = dstConn.ReadMsgUDP(conn.buf.data, conn.buf.oob) default: conn.buf.n, err = dstConn.Read(conn.buf.data[:conn.buf.n]) conn.buf.oobn = 0 } if err == nil { logger.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) } return } func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) { buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr) if err == nil { logger.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn) } return } func (conn *UDPConn) write() (err error) { switch dstConn := conn.conn.(type) { case *net.UDPConn: conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil) if err == nil { logger.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) } default: _, err = dstConn.Write(conn.buf.data[:conn.buf.n]) if err == nil { logger.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n) } } return } func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) { w.mu.Lock() defer w.mu.Unlock() dst, ok := w.connMap.Load(key) if !ok { var err error dst = conn dst.conn, err = w.dialDst() if err != nil { return nil, err } if err := dst.write(); err != nil { dst.conn.Close() return nil, err } w.connMap.Store(key, dst) } else { conn.conn = dst.conn if err := conn.write(); err != nil { w.connMap.Delete(key) dst.conn.Close() return nil, err } } return dst, nil } func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { conn, ok := streamConn.(*UDPConn) if !ok { panic("unexpected conn type") } key := conn.srcAddr.String() dst, err := w.getInitConn(conn, key) if err != nil { return err } for { select { case <-w.ctx.Done(): return nil default: if err := dst.read(); err != nil { w.connMap.Delete(key) dst.conn.Close() return err } if err := w.writeToSrc(dst.srcAddr, dst.buf); err != nil { return err } } } } func (w *UDPForwarder) Close() error { errs := E.NewBuilder("errors closing udp conn") w.mu.Lock() defer w.mu.Unlock() w.connMap.RangeAll(func(key string, conn *UDPConn) { errs.Add(conn.conn.Close()) }) w.connMap.Clear() errs.Add(w.forwarder.Close()) return errs.Error() }