package v1

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"sync"
	"time"

	"github.com/coder/websocket"
	"github.com/rs/zerolog"
	"github.com/yusing/go-proxy/internal/api/v1/utils"
	"github.com/yusing/go-proxy/internal/common"
	config "github.com/yusing/go-proxy/internal/config/types"
	"github.com/yusing/go-proxy/internal/logging"
	"github.com/yusing/go-proxy/internal/task"
	F "github.com/yusing/go-proxy/internal/utils/functional"
)

// logEntryRange locates one log entry inside memLogger's buffer as the
// half-open byte range [Start, End).
type logEntryRange struct {
	Start int // offset of the entry's first byte
	End   int // offset just past the entry's last byte
}

// memLogger keeps recent log output in an in-memory buffer and streams newly
// written entries to connected websocket clients.
//
//   - The embedded bytes.Buffer holds the log text; accesses to it are guarded
//     by the embedded sync.RWMutex.
//   - notifyLock is read-locked by notifyWS while it sends on connChans and
//     write-locked by ServeHTTP while it removes and closes a channel.
//   - connChans is the set of per-connection channels that receive the byte
//     range of each new entry.
type memLogger struct {
	bytes.Buffer
	sync.RWMutex
	notifyLock sync.RWMutex
	connChans  F.Map[chan *logEntryRange, struct{}]

	bufPool sync.Pool // used in hook mode (Run)
}

// MemLogger is the exported view of the in-memory logger: it can be plugged
// into a logger either as an io.Writer or as a zerolog.Hook.
type MemLogger interface {
	io.Writer
	// TODO: hook does not pass in fields, looking for a workaround to do server side log rendering
	zerolog.Hook
}

// buffer holds the scratch byte slice that Run borrows from memLogger.bufPool
// to format a single log entry.
type buffer struct {
	data []byte
}

// Sizes used by the in-memory logger, in bytes.
//
// maxMemLogSize caps the log buffer; when a write would exceed it,
// truncateIfNeeded keeps only the first truncateSize bytes. hookModeBufSize
// is the initial capacity of the pooled scratch buffers used by Run.
const (
	maxMemLogSize         = 16 << 10
	truncateSize          = maxMemLogSize / 2
	initialWriteChunkSize = 4 << 10
	hookModeBufSize       = 256
)

// memLoggerInstance is the shared memLogger behind GetMemLogger and LogsWS.
// Its bufPool hands Run scratch buffers with hookModeBufSize bytes of initial
// capacity.
var memLoggerInstance = &memLogger{
	connChans: F.NewMapOf[chan *logEntryRange, struct{}](),
	bufPool: sync.Pool{
		New: func() any {
			return &buffer{
				data: make([]byte, 0, hookModeBufSize),
			}
		},
	},
}

func init() {
	if !common.EnableLogStreaming {
		return
	}

	// Reserve the full maxMemLogSize up front so the buffer does not have to
	// reallocate as it fills.
	memLoggerInstance.Grow(maxMemLogSize)

	if !common.DebugMemLogger {
		return
	}

	// Debug aid: once per second, report the buffer size and the number of
	// registered websocket connections until the root context is canceled.
	ticker := time.NewTicker(time.Second)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				logging.Info().Msgf("mem logger size: %d, active conns: %d",
					memLoggerInstance.Len(),
					memLoggerInstance.connChans.Size())
			case <-task.RootContextCanceled():
				return
			}
		}
	}()
}

// LogsWS returns the handler that streams the in-memory log over a websocket:
// the ServeHTTP method bound to the shared memLogger instance.
func LogsWS() func(config config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
	handler := memLoggerInstance.ServeHTTP
	return handler
}

// Compile-time check that *memLogger satisfies MemLogger.
var _ MemLogger = (*memLogger)(nil)

// GetMemLogger returns the shared in-memory logger, usable as an io.Writer or
// a zerolog.Hook.
func GetMemLogger() MemLogger {
	var logger MemLogger = memLoggerInstance
	return logger
}

// truncateIfNeeded makes room for n more bytes. If writing them would push
// the buffer past maxMemLogSize, it drops everything after the first
// truncateSize bytes (bytes.Buffer.Truncate keeps the oldest data).
//
// The size is checked under the read lock first so the common case never
// takes the write lock. It is then re-checked under the write lock, because
// another writer may have truncated in between.
//
// NOTE(review): keeping the oldest half rather than the newest looks
// surprising for a log tail; confirm it is intended.
func (m *memLogger) truncateIfNeeded(n int) {
	overLimit := func() bool { return m.Len()+n > maxMemLogSize }

	m.RLock()
	mayNeedTruncate := overLimit()
	m.RUnlock()
	if !mayNeedTruncate {
		return
	}

	m.Lock()
	defer m.Unlock()
	if overLimit() {
		m.Truncate(truncateSize)
	}
}

// notifyWarnToken allows at most one notifyWS timeout warning to be written
// at a time. The warning goes through the package logger. If this memLogger
// feeds that logger (Run's comment says logging from inside it re-enters
// Run), writing the warning re-enters notifyWS. While a consumer stays slow,
// every nested call would then time out and warn again, recursing without
// bound.
var notifyWarnToken = make(chan struct{}, 1)

// notifyWS offers the byte range [pos, pos+n) of a newly written entry to
// every registered websocket connection channel.
//
// All sends share a single 1-second budget. Once it runs out, the remaining
// connections are skipped for this entry and a warning is logged, but only
// after notifyLock has been released.
func (m *memLogger) notifyWS(pos, n int) {
	if m.connChans.Size() == 0 {
		return
	}

	timeout := time.NewTimer(1 * time.Second)
	defer timeout.Stop()

	// Receivers (wsStreamLog) only read the range, so one value can be shared.
	entry := &logEntryRange{pos, pos + n}
	timedOut := false

	m.notifyLock.RLock()
	m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
		select {
		case ch <- entry:
			return true
		case <-timeout.C:
			timedOut = true
			return false
		}
	})
	// Unlock before warning. Logging can re-enter notifyWS, and sync.RWMutex
	// forbids recursive read locking: a nested RLock blocks forever once
	// ServeHTTP's cleanup is waiting in notifyLock.Lock(). That is exactly
	// the situation a timeout usually signals: a connection whose reader
	// has already returned.
	m.notifyLock.RUnlock()

	if !timedOut {
		return
	}
	select {
	case notifyWarnToken <- struct{}{}:
		defer func() { <-notifyWarnToken }()
		logging.Warn().Msg("mem logger: timeout logging to channel")
	default:
		// Another warning (possibly the one that led to this call) is
		// already being written; skip to avoid unbounded recursion.
	}
}

// writeBuf appends b to the buffer while holding the write lock. It returns
// the offset at which b starts, which callers pass on to websocket clients.
func (m *memLogger) writeBuf(b []byte) (pos int, err error) {
	m.Lock()
	defer m.Unlock()

	offset := m.Len()
	if _, werr := m.Buffer.Write(b); werr != nil {
		return offset, werr
	}
	return offset, nil
}

// Run implements zerolog.Hook.
//
// It formats the level and message with logging.FormatLogEntryHTML into a
// pooled scratch buffer, appends the result to the in-memory log (truncating
// first if needed), and notifies websocket clients of the new entry. The
// event's fields are not available here (see the TODO on MemLogger).
func (m *memLogger) Run(e *zerolog.Event, level zerolog.Level, message string) {
	pooled := m.bufPool.Get().(*buffer)
	defer func() {
		pooled.data = pooled.data[:0]
		m.bufPool.Put(pooled)
	}()

	entry := logging.FormatLogEntryHTML(level, message, pooled.data)
	size := len(entry)

	m.truncateIfNeeded(size)

	pos, err := m.writeBuf(entry)
	if err != nil {
		// Deliberately silent: logging here would call Run again and loop
		// forever.
		return
	}
	m.notifyWS(pos, size)
}

// Write implements io.Writer.
//
// It appends p to the in-memory log (truncating first if p would push it past
// maxMemLogSize) and notifies websocket clients. It always reports len(p) as
// the number of bytes written, along with any error from the buffer write.
func (m *memLogger) Write(p []byte) (n int, err error) {
	size := len(p)
	m.truncateIfNeeded(size)

	pos, werr := m.writeBuf(p)
	if werr != nil {
		// Deliberately not logged: logging could feed back into this writer
		// and loop forever.
		return size, werr
	}
	m.notifyWS(pos, size)
	return size, nil
}

// ServeHTTP opens a websocket for the request via utils.InitiateWS and sends
// the current log contents (wsInitial). It then streams every newly written
// entry (wsStreamLog) until the request context is done or a websocket write
// fails.
//
// The connection's channel is registered before the initial dump, so
// notifyWS offers entries written in the meantime to this connection instead
// of skipping it. On return, the channel is removed and closed while holding
// notifyLock's write lock.
func (m *memLogger) ServeHTTP(config config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
	conn, err := utils.InitiateWS(config, w, r)
	if err != nil {
		utils.HandleErr(w, r, err)
		return
	}

	// Registered before the initial dump; see the function comment.
	logCh := make(chan *logEntryRange)
	m.connChans.Store(logCh, struct{}{})

	/* trunk-ignore(golangci-lint/errcheck) */
	defer func() {
		_ = conn.CloseNow()

		// notifyWS only sends while holding notifyLock's read lock, so once
		// the write lock is held no send on logCh is in progress and closing
		// it cannot panic a sender.
		m.notifyLock.Lock()
		m.connChans.Delete(logCh)
		close(logCh)
		m.notifyLock.Unlock()
	}()

	if err := m.wsInitial(r.Context(), conn); err != nil {
		// NOTE(review): the websocket is already established here. Confirm
		// that utils.HandleErr does not try to write an HTTP error response
		// to w, which the upgrade has presumably hijacked.
		utils.HandleErr(w, r, err)
		return
	}

	m.wsStreamLog(r.Context(), conn, logCh)
}

// writeBytes sends b to conn as a single websocket text message.
func (m *memLogger) writeBytes(ctx context.Context, conn *websocket.Conn, b []byte) error {
	const kind = websocket.MessageText
	return conn.Write(ctx, kind, b)
}

// wsInitial sends everything currently in the buffer to conn as one text
// message.
//
// The contents are copied under the read lock and sent after it is released.
// Holding the lock across the network write would stall every writer
// (writeBuf needs the write lock) for as long as a slow client takes to
// accept the data.
func (m *memLogger) wsInitial(ctx context.Context, conn *websocket.Conn) error {
	m.RLock()
	snapshot := bytes.Clone(m.Buffer.Bytes())
	m.RUnlock()

	return m.writeBytes(ctx, conn, snapshot)
}

// wsStreamLog forwards each newly written log entry to conn until ctx is
// done, ch is closed, or a websocket write fails.
//
// Each range received on ch is copied out of the buffer under the read lock
// and sent after the lock is released, so a slow client never holds up
// writers (writeBuf needs the write lock). A range that no longer lies inside
// the buffer, because truncateIfNeeded discarded it in the meantime, is
// skipped rather than read from past the buffer's length.
func (m *memLogger) wsStreamLog(ctx context.Context, conn *websocket.Conn, ch <-chan *logEntryRange) {
	for {
		select {
		case <-ctx.Done():
			return
		case logRange, ok := <-ch:
			if !ok {
				// Closed channel: the connection was unregistered.
				return
			}

			m.RLock()
			live := m.Buffer.Bytes()
			inBounds := logRange != nil &&
				0 <= logRange.Start &&
				logRange.Start <= logRange.End &&
				logRange.End <= len(live)
			var msg []byte
			if inBounds {
				msg = bytes.Clone(live[logRange.Start:logRange.End])
			}
			m.RUnlock()

			if !inBounds {
				continue
			}
			if err := m.writeBytes(ctx, conn, msg); err != nil {
				return
			}
		}
	}
}