GoDoxy/internal/route/stream/tcp_tcp.go

162 lines
3 KiB
Go

package stream
import (
"context"
"net"
"github.com/rs/zerolog"
nettypes "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
"go.uber.org/atomic"
)
// TCPTCPStream accepts TCP connections on a local address and forwards each
// one to a fixed TCP destination.
type TCPTCPStream struct {
	// listener is assigned by ListenAndServe; it stays nil until then.
	listener *net.TCPListener

	// laddr and dst are the resolved listen and destination addresses,
	// filled in by NewTCPTCPStream.
	laddr, dst *net.TCPAddr

	// preDial and onRead are the optional hooks handed to ListenAndServe;
	// a nil hook is skipped.
	preDial, onRead nettypes.HookFunc

	// closed is flipped to true by the first call to Close.
	closed atomic.Bool
}
// NewTCPTCPStream resolves both addresses up front and returns a stream that
// is ready for ListenAndServe. The destination address is resolved first, so
// its error wins when both addresses are invalid.
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
	var (
		stream TCPTCPStream
		err    error
	)
	if stream.dst, err = net.ResolveTCPAddr("tcp", dstAddr); err != nil {
		return nil, err
	}
	if stream.laddr, err = net.ResolveTCPAddr("tcp", listenAddr); err != nil {
		return nil, err
	}
	return &stream, nil
}
// ListenAndServe binds s.laddr and starts the accept loop on a new goroutine.
// It returns immediately after the listener is bound; a bind failure is
// logged and the stream is left without a listener.
//
// preDial and onRead are optional hooks (nil is allowed). They are stored on
// the stream and used by the accept loop and the per-connection handler.
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
	listener, err := net.ListenTCP("tcp", s.laddr)
	if err != nil {
		logErr(s, err, "failed to listen")
		return
	}
	s.preDial = preDial
	s.onRead = onRead
	s.listener = listener

	// Close may have run while we were binding. In that case it found no
	// listener and returned without closing anything, so close the freshly
	// bound one here instead of leaking the port.
	// NOTE(review): s.listener is still read by Close without a lock; a
	// mutex or atomic pointer would make this fully race-free.
	if s.closed.Load() {
		listener.Close()
		return
	}
	go s.listen(ctx)
}
// Close marks the stream closed and closes the listener if one was bound.
// Only the first call has any effect; later calls, and calls made before
// ListenAndServe bound a listener, return nil.
func (s *TCPTCPStream) Close() error {
	wasClosed := s.closed.Swap(true)
	if wasClosed {
		return nil
	}
	if s.listener == nil {
		return nil
	}
	return s.listener.Close()
}
// LocalAddr reports the bound listener's address when there is one, and
// otherwise the resolved listen address passed to NewTCPTCPStream.
func (s *TCPTCPStream) LocalAddr() net.Addr {
	if l := s.listener; l != nil {
		return l.Addr()
	}
	return s.laddr
}
// MarshalZerologObject adds the stream's protocol, listen address and
// destination address to a zerolog event.
//
// The listen address comes from LocalAddr rather than s.listener directly.
// ListenAndServe logs a bind failure before any listener exists, and
// dereferencing a nil s.listener there would panic.
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
	e.Str("protocol", "tcp-tcp").
		Str("listen", s.LocalAddr().String()).
		Str("dst", s.dst.String())
}
// listen is the accept loop started by ListenAndServe. It stops once the
// stream is closed or ctx is cancelled. Cancellation is only checked between
// Accept calls, because Accept itself does not watch ctx.
//
// For every accepted connection, onRead (if set) runs once. If it fails, the
// connection is closed and dropped; otherwise it is served by handle on its
// own goroutine.
func (s *TCPTCPStream) listen(ctx context.Context) {
	for {
		if s.closed.Load() {
			return
		}
		select {
		case <-ctx.Done():
			return
		default:
		}

		conn, err := s.listener.Accept()
		if err != nil {
			if s.closed.Load() {
				return
			}
			logErr(s, err, "failed to accept connection")
			continue
		}

		if s.onRead != nil {
			if err := s.onRead(ctx); err != nil {
				logErr(s, err, "failed to on read")
				// handle never takes ownership of this connection, so it must be
				// closed here or its file descriptor leaks.
				conn.Close()
				continue
			}
		}
		go s.handle(ctx, conn)
	}
}
// handle serves one accepted client connection. It proxies to s.dst and
// closes both sides on return.
//
// Steps:
//   - run preDial (if set);
//   - dial the destination;
//   - hand both ends to utils.NewBidirectionalPipe.
//
// When onRead is set, both ends are wrapped so the hook also runs after each
// successful Read. Errors are logged only while the stream has not been
// closed, so shutdown via Close does not produce log noise.
func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
	defer conn.Close()

	if s.preDial != nil {
		if err := s.preDial(ctx); err != nil {
			if !s.closed.Load() {
				logErr(s, err, "failed to pre-dial")
			}
			return
		}
	}
	if s.closed.Load() {
		return
	}

	// Unlike net.DialTCP, DialContext aborts the dial when ctx is cancelled,
	// so a slow or unreachable upstream does not keep dialing after shutdown.
	// s.dst is already resolved, so no new DNS lookup happens here.
	var dialer net.Dialer
	dstConn, err := dialer.DialContext(ctx, "tcp", s.dst.String())
	if err != nil {
		if !s.closed.Load() {
			logErr(s, err, "failed to dial destination")
		}
		return
	}
	defer dstConn.Close()
	if s.closed.Load() {
		return
	}

	src, dst := conn, dstConn
	if s.onRead != nil {
		src = &wrapperConn{Conn: conn, ctx: ctx, onRead: s.onRead}
		dst = &wrapperConn{Conn: dstConn, ctx: ctx, onRead: s.onRead}
	}

	pipe := utils.NewBidirectionalPipe(ctx, src, dst)
	if err := pipe.Start(); err != nil && !s.closed.Load() {
		logErr(s, err, "error in bidirectional pipe")
	}
}
// wrapperConn decorates a net.Conn so that every Read that returns no error
// also invokes onRead with the stored context. All other net.Conn methods
// pass straight through to the embedded connection.
type wrapperConn struct {
	net.Conn
	ctx    context.Context
	onRead nettypes.HookFunc
}

// Read reads from the underlying connection. A read error is returned as-is,
// without calling the hook. Otherwise the hook's error (nil on success, or
// nil when no hook is set) is returned alongside the byte count.
func (w *wrapperConn) Read(b []byte) (int, error) {
	count, readErr := w.Conn.Read(b)
	if readErr != nil || w.onRead == nil {
		return count, readErr
	}
	return count, w.onRead(w.ctx)
}