package route

import (
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"time"

	"github.com/yusing/go-proxy/internal/net/types"
	T "github.com/yusing/go-proxy/internal/proxy/fields"
	U "github.com/yusing/go-proxy/internal/utils"
	F "github.com/yusing/go-proxy/internal/utils/functional"
)

type (
	// UDPRoute forwards UDP datagrams received on a local listening socket to
	// a single target address, keeping one UDPConn per client source address.
	UDPRoute struct {
		*StreamRoute

		connMap       UDPConnMap
		listeningConn *net.UDPConn
		targetAddr    *net.UDPAddr
	}

	// UDPConn is the per-client state of a UDPRoute: a socket connected to the
	// client address (src), a socket connected to the route target (dst), and
	// the pipe that relays data between them.
	UDPConn struct {
		key string
		src *net.UDPConn
		dst *net.UDPConn
		U.BidirectionalPipe
	}

	// UDPConnMap maps a client source address (net.UDPAddr.String form) to
	// the UDPConn serving that client.
	UDPConnMap = F.Map[string, *UDPConn]
)

// NewUDPConnMap returns an empty UDPConnMap.
var NewUDPConnMap = F.NewMap[UDPConnMap]

// udpBufferSize is the size of the buffer each incoming datagram is read into.
// NOTE(review): on common platforms the kernel discards the bytes of a
// datagram that do not fit, so larger datagrams are forwarded truncated.
const udpBufferSize = 8192

// NewUDPRoute wraps base as a UDPRoute with an empty connection map.
func NewUDPRoute(base *StreamRoute) *UDPRoute {
	return &UDPRoute{
		StreamRoute: base,
		connMap:     NewUDPConnMap(),
	}
}

// Setup resolves the listening and target addresses and opens the listening
// socket. The port the socket actually bound is written back to
// route.Port.ListeningPort (this is how a configured port of 0 gets the
// OS-allocated port).
func (route *UDPRoute) Setup() error {
	laddr, err := net.ResolveUDPAddr(string(route.Scheme.ListeningScheme), fmt.Sprintf(":%v", route.Port.ListeningPort))
	if err != nil {
		return err
	}

	// Resolve the target before opening the listener so a bad target does
	// not leave an open socket behind.
	target := udpJoinHostPort(fmt.Sprint(route.Host), fmt.Sprint(route.Port.ProxyPort))
	raddr, err := net.ResolveUDPAddr(string(route.Scheme.ProxyScheme), target)
	if err != nil {
		return err
	}

	source, err := net.ListenUDP(string(route.Scheme.ListeningScheme), laddr)
	if err != nil {
		return err
	}

	// Read back the port that was actually bound (the OS picks one for ":0").
	route.Port.ListeningPort = T.Port(source.LocalAddr().(*net.UDPAddr).Port)
	route.listeningConn = source
	route.targetAddr = raddr
	return nil
}

// udpJoinHostPort combines host and port into a "host:port" address,
// bracketing IPv6 literals as net.JoinHostPort does. A host that is already
// bracketed (e.g. "[::1]") is unwrapped first so it is not double-bracketed.
func udpJoinHostPort(host, port string) string {
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, port)
}

// Accept reads one datagram from the listening socket, waiting at most one
// second, and forwards it to the target through the UDPConn registered for
// the datagram's source address. On first contact from a source address a new
// UDPConn is created; it is registered only if that first datagram was
// forwarded successfully.
//
// Errors from the socket (including read-deadline timeouts) are returned
// unwrapped. A zero-length datagram yields io.ErrShortBuffer.
func (route *UDPRoute) Accept() (types.StreamConn, error) {
	in := route.listeningConn

	buffer := make([]byte, udpBufferSize)
	if err := in.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
		return nil, err
	}
	nRead, srcAddr, err := in.ReadFromUDP(buffer)
	if err != nil {
		return nil, err
	}
	if nRead == 0 {
		return nil, io.ErrShortBuffer
	}

	key := srcAddr.String()
	if existing, ok := route.connMap.Load(key); ok {
		// NOTE(review): an already-registered conn is returned again here;
		// confirm the caller does not start its pipe a second time.
		_, err = existing.dst.Write(buffer[:nRead])
		return existing, err
	}

	conn, err := route.dialConn(key, srcAddr)
	if err != nil {
		return nil, err
	}
	if _, err := conn.dst.Write(buffer[:nRead]); err != nil {
		// Do not register a conn whose first datagram could not be forwarded;
		// close its sockets so they are not leaked. The write error is the
		// one worth reporting, so the close error is intentionally dropped.
		_ = conn.Close()
		return nil, err
	}
	route.connMap.Store(key, conn)
	return conn, nil
}

// dialConn opens a socket connected to the client at srcAddr and another
// connected to the route target, and builds the pipe relaying between them.
func (route *UDPRoute) dialConn(key string, srcAddr *net.UDPAddr) (*UDPConn, error) {
	srcConn, err := net.DialUDP("udp", nil, srcAddr)
	if err != nil {
		return nil, err
	}
	dstConn, err := net.DialUDP("udp", nil, route.targetAddr)
	if err != nil {
		srcConn.Close()
		return nil, err
	}
	return &UDPConn{
		key: key,
		src: srcConn,
		dst: dstConn,
		BidirectionalPipe: U.NewBidirectionalPipe(
			route.task.Context(),
			sourceRWCloser{server: route.listeningConn, UDPConn: dstConn},
			sourceRWCloser{server: route.listeningConn, UDPConn: srcConn},
		),
	}, nil
}

// Handle runs the relay pipe of c, which must be a *UDPConn, via its Start
// method and then unregisters it from the route once Start returns.
func (route *UDPRoute) Handle(c types.StreamConn) error {
	conn := c.(*UDPConn)
	err := conn.Start()
	route.connMap.Delete(conn.key)
	return err
}

// CloseListeners closes the listening socket and every registered UDPConn,
// logging (not returning) close errors, then empties the connection map.
func (route *UDPRoute) CloseListeners() {
	if route.listeningConn != nil {
		route.listeningConn.Close()
	}
	route.connMap.RangeAllParallel(func(_ string, conn *UDPConn) {
		if err := conn.Close(); err != nil {
			route.l.Errorf("error closing conn: %s", err)
		}
	})
	route.connMap.Clear()
}

// Close implements types.StreamConn. It closes both the client-side and the
// target-side sockets and joins their errors.
func (conn *UDPConn) Close() error {
	return errors.Join(conn.src.Close(), conn.dst.Close())
}

// RemoteAddr implements types.StreamConn. It returns the client's address.
func (conn *UDPConn) RemoteAddr() net.Addr {
	return conn.src.RemoteAddr()
}

// sourceRWCloser reads (and closes) via the embedded connected socket, but
// writes through server, addressed to the embedded socket's remote address.
// Data written therefore leaves from the server socket rather than from the
// embedded socket's own local port.
type sourceRWCloser struct {
	server *net.UDPConn
	*net.UDPConn
}

// Write sends p from the server socket to the embedded socket's remote address.
func (w sourceRWCloser) Write(p []byte) (int, error) {
	return w.server.WriteToUDP(p, w.RemoteAddr().(*net.UDPAddr)) // TODO: support non udp
}