package stream

import (
	"context"
	"net"

	"github.com/rs/zerolog"
	nettypes "github.com/yusing/go-proxy/internal/net/types"
	"github.com/yusing/go-proxy/internal/utils"
	"go.uber.org/atomic"
)

// TCPTCPStream accepts TCP connections on laddr and forwards each one to the
// TCP destination dst, dialing a separate upstream connection per client.
type TCPTCPStream struct {
	// listener is created by ListenAndServe; it is nil before that call
	// (or if listening failed).
	listener *net.TCPListener
	// laddr is the resolved address the stream listens on.
	laddr *net.TCPAddr
	// dst is the resolved upstream address every client is forwarded to.
	dst *net.TCPAddr

	// preDial, if non-nil, is run before dialing dst for each connection.
	preDial nettypes.HookFunc
	// onRead, if non-nil, is run after each accepted connection and after
	// every successful Read on either side of a proxied connection.
	onRead nettypes.HookFunc

	// closed is set by the first call to Close; it stops the accept loop and
	// suppresses error logging for failures caused by shutdown.
	closed atomic.Bool
}

// NewTCPTCPStream resolves listenAddr and dstAddr as TCP addresses and returns
// a stream that will forward connections from the former to the latter.
// Nothing is bound until ListenAndServe is called. The destination is
// resolved first; the first resolution error is returned as-is.
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
	var (
		stream = new(TCPTCPStream)
		err    error
	)
	if stream.dst, err = net.ResolveTCPAddr("tcp", dstAddr); err != nil {
		return nil, err
	}
	if stream.laddr, err = net.ResolveTCPAddr("tcp", listenAddr); err != nil {
		return nil, err
	}
	return stream, nil
}

// ListenAndServe binds a TCP listener on s.laddr, records the hooks and starts
// the accept loop in a background goroutine, then returns immediately. A
// listen failure is logged rather than returned.
//
// preDial (optional) runs before each upstream dial. onRead (optional) runs
// after each accepted connection and after every successful Read on either
// side of a proxied connection.
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
	listener, err := net.ListenTCP("tcp", s.laddr)
	if err != nil {
		logErr(s, err, "failed to listen")
		return
	}
	s.listener = listener
	s.preDial = preDial
	s.onRead = onRead

	// If Close ran before the listener existed, it saw a nil listener and
	// closed nothing; the accept loop would then exit at once and leave the
	// socket bound for the life of the process. Release it here instead.
	if s.closed.Load() {
		_ = listener.Close() // best-effort: the stream is already shut down
		return
	}
	go s.listen(ctx)
}

// Close shuts the stream down. Only the first call does anything: it marks
// the stream closed and, if a listener has been created, closes it and
// returns that result. Repeat calls, and a call made before any listener
// exists, return nil.
func (s *TCPTCPStream) Close() error {
	if wasClosed := s.closed.Swap(true); wasClosed {
		return nil
	}
	if s.listener != nil {
		return s.listener.Close()
	}
	return nil
}

// LocalAddr reports the address the stream is bound to. Once a listener
// exists this is the listener's actual address (e.g. with a concrete port);
// before that it is the configured listen address.
func (s *TCPTCPStream) LocalAddr() net.Addr {
	if l := s.listener; l != nil {
		return l.Addr()
	}
	return s.laddr
}

// MarshalZerologObject adds the stream's protocol, listen address and
// destination to a zerolog event (zerolog.LogObjectMarshaler).
//
// The listen address comes from LocalAddr, which falls back to the configured
// address while no listener exists. This matters because the stream is
// logged when ListenTCP itself fails, at which point s.listener is still nil;
// calling s.listener.Addr() there would dereference a nil pointer.
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
	e.Str("protocol", "tcp-tcp").
		Str("listen", s.LocalAddr().String()).
		Str("dst", s.dst.String())
}

// listen is the accept loop started by ListenAndServe. Each accepted
// connection is handed to handle in its own goroutine. If onRead is set, it
// runs once per accepted connection before the hand-off, and an error from it
// rejects that connection. The loop stops once the stream is closed or ctx is
// done.
//
// NOTE(review): Accept is not interrupted by ctx cancellation; ctx is only
// observed between accepts or after Accept fails. Closing the stream (which
// closes the listener) is what unblocks a pending Accept.
func (s *TCPTCPStream) listen(ctx context.Context) {
	for !s.closed.Load() {
		select {
		case <-ctx.Done():
			return
		default:
		}

		conn, err := s.listener.Accept()
		if err != nil {
			// Errors caused by shutdown are expected; don't log them.
			if s.closed.Load() || ctx.Err() != nil {
				return
			}
			logErr(s, err, "failed to accept connection")
			continue
		}

		if s.onRead != nil {
			if err := s.onRead(ctx); err != nil {
				logErr(s, err, "failed to on read")
				// This connection never reaches handle (which owns closing
				// it), so close it here to avoid leaking the socket.
				_ = conn.Close()
				continue
			}
		}
		go s.handle(ctx, conn)
	}
}

// handle proxies one accepted client connection to s.dst and closes both
// connections on return. Steps:
//   - run the optional preDial hook;
//   - dial the destination;
//   - when onRead is set, wrap both ends so the hook fires on every
//     successful Read;
//   - pipe data in both directions.
//
// Failures are logged unless the stream has already been closed.
//
// NOTE(review): this assumes pipe.Start blocks until both directions finish,
// since the deferred closes run as soon as it returns. Confirm against
// utils.NewBidirectionalPipe.
func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
	defer conn.Close()

	if s.preDial != nil {
		if err := s.preDial(ctx); err != nil {
			if !s.closed.Load() {
				logErr(s, err, "failed to pre-dial")
			}
			return
		}
	}

	if s.closed.Load() {
		return
	}

	// Dial through a Dialer so the attempt is abandoned when ctx is
	// cancelled, rather than blocking until the OS connect timeout.
	var dialer net.Dialer
	dstConn, err := dialer.DialContext(ctx, "tcp", s.dst.String())
	if err != nil {
		if !s.closed.Load() {
			logErr(s, err, "failed to dial destination")
		}
		return
	}
	defer dstConn.Close()

	if s.closed.Load() {
		return
	}

	src, dst := conn, dstConn
	if s.onRead != nil {
		src = &wrapperConn{Conn: conn, ctx: ctx, onRead: s.onRead}
		dst = &wrapperConn{Conn: dstConn, ctx: ctx, onRead: s.onRead}
	}

	pipe := utils.NewBidirectionalPipe(ctx, src, dst)
	if err := pipe.Start(); err != nil && !s.closed.Load() {
		logErr(s, err, "error in bidirectional pipe")
	}
}

// wrapperConn decorates a net.Conn so that its Read also runs the onRead
// hook (when non-nil) after each read that returned no error. All other
// net.Conn methods are promoted unchanged from the embedded Conn.
type wrapperConn struct {
	net.Conn
	// ctx is the context passed to onRead on every call.
	ctx    context.Context
	onRead nettypes.HookFunc
}

// Read reads from the underlying connection. If that read succeeds and an
// onRead hook is set, the hook is run and its error (if any) is returned
// alongside the byte count. A failed read is returned as-is, and the hook
// is not run.
func (w *wrapperConn) Read(b []byte) (n int, err error) {
	if n, err = w.Conn.Read(b); err == nil && w.onRead != nil {
		err = w.onRead(w.ctx)
	}
	return n, err
}