mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-10 21:32:33 +02:00
162 lines
3 KiB
Go
162 lines
3 KiB
Go
package stream
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
|
|
"github.com/rs/zerolog"
|
|
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
|
"github.com/yusing/go-proxy/internal/utils"
|
|
"go.uber.org/atomic"
|
|
)
|
|
|
|
type TCPTCPStream struct {
|
|
listener *net.TCPListener
|
|
laddr *net.TCPAddr
|
|
dst *net.TCPAddr
|
|
|
|
preDial nettypes.HookFunc
|
|
onRead nettypes.HookFunc
|
|
|
|
closed atomic.Bool
|
|
}
|
|
|
|
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
|
dst, err := net.ResolveTCPAddr("tcp", dstAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
laddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &TCPTCPStream{laddr: laddr, dst: dst}, nil
|
|
}
|
|
|
|
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
|
listener, err := net.ListenTCP("tcp", s.laddr)
|
|
if err != nil {
|
|
logErr(s, err, "failed to listen")
|
|
return
|
|
}
|
|
s.listener = listener
|
|
s.preDial = preDial
|
|
s.onRead = onRead
|
|
go s.listen(ctx)
|
|
}
|
|
|
|
func (s *TCPTCPStream) Close() error {
|
|
if s.closed.Swap(true) || s.listener == nil {
|
|
return nil
|
|
}
|
|
return s.listener.Close()
|
|
}
|
|
|
|
func (s *TCPTCPStream) LocalAddr() net.Addr {
|
|
if s.listener == nil {
|
|
return s.laddr
|
|
}
|
|
return s.listener.Addr()
|
|
}
|
|
|
|
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
|
|
e.Str("protocol", "tcp-tcp").Str("listen", s.listener.Addr().String()).Str("dst", s.dst.String())
|
|
}
|
|
|
|
func (s *TCPTCPStream) listen(ctx context.Context) {
|
|
for {
|
|
if s.closed.Load() {
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
if s.closed.Load() {
|
|
return
|
|
}
|
|
logErr(s, err, "failed to accept connection")
|
|
continue
|
|
}
|
|
if s.onRead != nil {
|
|
if err := s.onRead(ctx); err != nil {
|
|
logErr(s, err, "failed to on read")
|
|
continue
|
|
}
|
|
}
|
|
go s.handle(ctx, conn)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
if s.preDial != nil {
|
|
if err := s.preDial(ctx); err != nil {
|
|
if !s.closed.Load() {
|
|
logErr(s, err, "failed to pre-dial")
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
if s.closed.Load() {
|
|
return
|
|
}
|
|
|
|
dstConn, err := net.DialTCP("tcp", nil, s.dst)
|
|
if err != nil {
|
|
if !s.closed.Load() {
|
|
logErr(s, err, "failed to dial destination")
|
|
}
|
|
return
|
|
}
|
|
defer dstConn.Close()
|
|
|
|
if s.closed.Load() {
|
|
return
|
|
}
|
|
|
|
src := conn
|
|
dst := net.Conn(dstConn)
|
|
if s.onRead != nil {
|
|
src = &wrapperConn{
|
|
Conn: conn,
|
|
ctx: ctx,
|
|
onRead: s.onRead,
|
|
}
|
|
dst = &wrapperConn{
|
|
Conn: dstConn,
|
|
ctx: ctx,
|
|
onRead: s.onRead,
|
|
}
|
|
}
|
|
|
|
pipe := utils.NewBidirectionalPipe(ctx, src, dst)
|
|
if err := pipe.Start(); err != nil && !s.closed.Load() {
|
|
logErr(s, err, "error in bidirectional pipe")
|
|
}
|
|
}
|
|
|
|
type wrapperConn struct {
|
|
net.Conn
|
|
ctx context.Context
|
|
onRead nettypes.HookFunc
|
|
}
|
|
|
|
func (w *wrapperConn) Read(b []byte) (n int, err error) {
|
|
n, err = w.Conn.Read(b)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if w.onRead != nil {
|
|
if err = w.onRead(w.ctx); err != nil {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|