package idlewatcher

import (
	"context"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/docker/docker/api/types/container"
	"github.com/sirupsen/logrus"
	D "github.com/yusing/go-proxy/internal/docker"
	E "github.com/yusing/go-proxy/internal/error"
	P "github.com/yusing/go-proxy/internal/proxy"
	PT "github.com/yusing/go-proxy/internal/proxy/fields"
	W "github.com/yusing/go-proxy/internal/watcher"
)

type (
	// watcher tracks a single docker container and stops it (using the
	// configured stop method) after IdleTimeout without a wake signal.
	// One watcher exists per container ID; see Register.
	watcher struct {
		// Entry the watcher was registered for. Register replaces it when the
		// same container is registered again.
		*P.ReverseProxyEntry

		// Docker client used for every container operation; closed by the
		// goroutine spawned in Start once the watcher is finished.
		client D.Client

		ready        atomic.Bool  // whether the site is ready to accept connection
		stopByMethod StopCallback // send a docker command w.r.t. `stop_method`

		// wakeCh delivers wake requests to watchUntilCancel, which answers
		// with the outcome of the wake attempt on wakeDone.
		wakeCh   chan struct{}
		wakeDone chan E.NestedError

		// ctx/cancel are created by watchUntilCancel and bound the docker calls.
		// refCount is incremented by Register and decremented by Unregister.
		ctx      context.Context
		cancel   context.CancelFunc
		refCount *sync.WaitGroup

		// Logger carrying the "container" field set in Register.
		l logrus.FieldLogger
	}

	// WakeDone is a receive-only channel of wake results and WakeFunc is a
	// function returning one; neither is used in this part of the file
	// (presumably consumed by the round tripper — verify against callers).
	// StopCallback is the signature of the function stored in
	// watcher.stopByMethod, see getStopCallback.
	WakeDone     <-chan error
	WakeFunc     func() WakeDone
	StopCallback func() E.NestedError
)

var (
	// Context, cancel function and wait group of the main loop. The context
	// and cancel function are assigned by Start; the wait group counts the
	// per-watcher goroutines spawned by Start and is waited on by Stop.
	mainLoopCtx    context.Context
	mainLoopCancel context.CancelFunc
	mainLoopWg     sync.WaitGroup

	// Registered watchers keyed by container ID. Register holds watcherMapMu
	// while reading and writing the map.
	watcherMap   = make(map[string]*watcher)
	watcherMapMu sync.Mutex

	// Unbuffered hand-off of newly created watchers from Register to Start.
	newWatcherCh = make(chan *watcher)

	// Package logger; per-watcher loggers are derived from it in Register.
	logger = logrus.WithField("module", "idle_watcher")
)

// Register returns the idle watcher for the container described by entry,
// creating it when the container has not been registered before.
//
// An entry with a zero IdleTimeout is rejected. When a watcher already exists
// for entry.ContainerID its reference count is incremented and its entry is
// replaced by the new one; otherwise a docker client is connected, a fresh
// watcher is stored in watcherMap and handed to the main loop (see Start)
// through newWatcherCh.
//
// The returned error is non-nil when the timeout is invalid or the docker
// client cannot be connected.
func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) {
	failure := E.Failure("idle_watcher register")

	if entry.IdleTimeout == 0 {
		return nil, failure.With(E.Invalid("idle_timeout", 0))
	}

	watcherMapMu.Lock()
	defer watcherMapMu.Unlock()

	id := entry.ContainerID

	// Reuse the watcher already tracking this container.
	existing, found := watcherMap[id]
	if found {
		existing.refCount.Add(1)
		existing.ReverseProxyEntry = entry
		return existing, nil
	}

	dockerClient, connErr := D.ConnectClient(entry.DockerHost)
	if connErr.HasError() {
		return nil, failure.With(connErr)
	}

	// The caller holds the first reference.
	refs := new(sync.WaitGroup)
	refs.Add(1)

	created := &watcher{
		ReverseProxyEntry: entry,
		client:            dockerClient,
		refCount:          refs,
		wakeCh:            make(chan struct{}),
		wakeDone:          make(chan E.NestedError),
		l:                 logger.WithField("container", entry.ContainerName),
	}
	created.stopByMethod = created.getStopCallback()

	watcherMap[id] = created

	// newWatcherCh is unbuffered, so hand the watcher over from a separate
	// goroutine instead of blocking the caller until the main loop receives it.
	go func(nw *watcher) {
		newWatcherCh <- nw
	}(created)

	return created, nil
}

// Unregister drops one reference from the watcher registered for
// entry.ContainerID. It is a no-op when no watcher is registered for that
// container.
//
// The map lookup is done under watcherMapMu because Register mutates
// watcherMap under the same mutex; an unguarded read concurrent with a map
// write is a data race.
func Unregister(entry *P.ReverseProxyEntry) {
	watcherMapMu.Lock()
	defer watcherMapMu.Unlock()

	if w, ok := watcherMap[entry.ContainerID]; ok {
		w.refCount.Add(-1)
	}
}

// Start runs the idle watcher main loop. It creates the main loop context and
// blocks until that context is cancelled (see Stop).
//
// Every watcher received on newWatcherCh gets its own goroutine, tracked by
// mainLoopWg. The goroutine runs the watcher's event loop, waits until the
// watcher's reference count drops to zero, then closes its docker client and
// removes it from watcherMap.
func Start() {
	logger.Debug("started")
	defer logger.Debug("stopped")

	mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background())

	for {
		select {
		case <-mainLoopCtx.Done():
			return
		case w := <-newWatcherCh:
			w.l.Debug("registered")
			mainLoopWg.Add(1)
			go func() {
				defer mainLoopWg.Done()

				w.watchUntilCancel()
				w.refCount.Wait() // wait for 0 ref count

				w.client.Close()

				// watcherMap is shared with Register, which accesses it under
				// watcherMapMu; deleting without the mutex would be a
				// concurrent map write.
				watcherMapMu.Lock()
				delete(watcherMap, w.ContainerID)
				watcherMapMu.Unlock()

				w.l.Debug("unregistered")
			}()
		}
	}
}

// Stop cancels the main loop context created by Start and blocks until every
// watcher goroutine spawned by Start has finished.
//
// Calling Stop when Start has not run yet is a no-op apart from the
// (immediately returning) wait.
func Stop() {
	// mainLoopCancel is only assigned inside Start; invoking a nil func
	// would panic.
	if mainLoopCancel != nil {
		mainLoopCancel()
	}
	mainLoopWg.Wait()
}

// PatchRoundTripper wraps rtp in a roundTripper whose patched function sends
// every request through w.roundTrip, handing it rtp.RoundTrip as the
// underlying round trip function.
func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper {
	patched := func(req *http.Request) (*http.Response, error) {
		return w.roundTrip(rtp.RoundTrip, req)
	}
	return roundTripper{patched: patched}
}

// containerStop asks docker to stop the watched container, passing the
// entry's StopSignal and a pointer to its StopTimeout as stop options.
func (w *watcher) containerStop() error {
	opts := container.StopOptions{
		Signal:  string(w.StopSignal),
		Timeout: &w.StopTimeout,
	}
	return w.client.ContainerStop(w.ctx, w.ContainerID, opts)
}

// containerPause asks docker to pause the watched container.
func (w *watcher) containerPause() error {
	ctx, id := w.ctx, w.ContainerID
	return w.client.ContainerPause(ctx, id)
}

// containerKill asks docker to kill the watched container using the entry's
// StopSignal.
func (w *watcher) containerKill() error {
	signal := string(w.StopSignal)
	return w.client.ContainerKill(w.ctx, w.ContainerID, signal)
}

// containerUnpause asks docker to unpause the watched container.
func (w *watcher) containerUnpause() error {
	ctx, id := w.ctx, w.ContainerID
	return w.client.ContainerUnpause(ctx, id)
}

// containerStart asks docker to start the watched container with
// zero-valued start options.
func (w *watcher) containerStart() error {
	var opts container.StartOptions
	return w.client.ContainerStart(w.ctx, w.ContainerID, opts)
}

// containerStatus inspects the watched container and returns the status
// string from its state. When the inspect call fails it returns an empty
// status and the error wrapped as "inspect container".
func (w *watcher) containerStatus() (string, E.NestedError) {
	info, inspectErr := w.client.ContainerInspect(w.ctx, w.ContainerID)
	if inspectErr == nil {
		return info.State.Status, nil
	}
	return "", E.FailWith("inspect container", inspectErr)
}

// wakeIfStopped brings the watched container back up based on its current
// status: an exited or dead container is started, a paused one is unpaused,
// and a running one is left alone. Any other status is reported as an
// unexpected container state. A failure to query the status is returned
// as-is.
func (w *watcher) wakeIfStopped() E.NestedError {
	state, statusErr := w.containerStatus()
	if statusErr.HasError() {
		return statusErr
	}

	// possible values: "created", "running", "paused", "restarting",
	// "removing", "exited", or "dead"
	switch state {
	case "running":
		return nil
	case "paused":
		return E.From(w.containerUnpause())
	case "exited", "dead":
		return E.From(w.containerStart())
	}
	return E.Unexpected("container state", state)
}

// getStopCallback builds the StopCallback matching the entry's StopMethod
// (pause, stop or kill). It panics on any other stop method.
//
// The returned callback first checks the container status: a status query
// error is returned unchanged, a container that is not "running" is left
// untouched (nil is returned), and only a running container receives the
// selected docker command.
func (w *watcher) getStopCallback() StopCallback {
	var stop func() error

	switch method := w.StopMethod; method {
	case PT.StopMethodKill:
		stop = w.containerKill
	case PT.StopMethodStop:
		stop = w.containerStop
	case PT.StopMethodPause:
		stop = w.containerPause
	default:
		panic("should not reach here")
	}

	return func() E.NestedError {
		state, statusErr := w.containerStatus()
		switch {
		case statusErr.HasError():
			return statusErr
		case state != "running":
			// nothing to stop
			return nil
		}
		return E.From(stop())
	}
}

// watchUntilCancel is the watcher's event loop. It creates w.ctx / w.cancel,
// subscribes to docker events for the watched container and then loops until
// w.ctx is done, reacting to:
//   - cancellation of the main loop context (cancels w.ctx),
//   - docker events (re-arms or stops the idle ticker),
//   - the idle ticker firing (stops the container via w.stopByMethod),
//   - wake signals on w.wakeCh (wakes the container and reports on w.wakeDone).
//
// w.wakeCh is closed when the function returns.
func (w *watcher) watchUntilCancel() {
	defer close(w.wakeCh)

	w.ctx, w.cancel = context.WithCancel(context.Background())

	// Subscribe only to lifecycle events of this watcher's container.
	dockerWatcher := W.NewDockerWatcherWithClient(w.client)
	dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{
		Filters: W.NewDockerFilter(
			W.DockerFilterContainer,
			W.DockerrFilterContainer(w.ContainerID),
			W.DockerFilterStart,
			W.DockerFilterStop,
			W.DockerFilterDie,
			W.DockerFilterKill,
			W.DockerFilterPause,
			W.DockerFilterUnpause,
		),
	})

	// The ticker acts as the idle timer: it is stopped once it fires or when
	// the container goes down, and re-armed on container wake events and on
	// wake signals.
	ticker := time.NewTicker(w.IdleTimeout)
	defer ticker.Stop()

	for {
		select {
		case <-mainLoopCtx.Done():
			// Propagate shutdown; a subsequent iteration returns through the
			// w.ctx.Done() case below.
			w.cancel()
		case <-w.ctx.Done():
			w.l.Debug("stopped")
			return
		case err := <-dockerEventErrCh:
			// Cancellation is expected on shutdown and is not logged.
			if err != nil && err.IsNot(context.Canceled) {
				w.l.Error(E.FailWith("docker watcher", err))
			}
		case e := <-dockerEventCh:
			switch {
			// create / start / unpause
			case e.Action.IsContainerWake():
				ticker.Reset(w.IdleTimeout)
				w.l.Info(e)
			default: // stop / pause / kill
				// Container is down: no idle countdown needed and the site is
				// no longer ready.
				ticker.Stop()
				w.ready.Store(false)
				w.l.Info(e)
			}
		case <-ticker.C:
			w.l.Debug("idle timeout")
			// Stop the ticker so it does not fire again until it is reset.
			ticker.Stop()
			if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) {
				w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod))
			}
		case <-w.wakeCh:
			w.l.Debug("wake signal received")
			// A wake request restarts the idle countdown.
			ticker.Reset(w.IdleTimeout)
			err := w.wakeIfStopped()
			if err != nil && err.IsNot(context.Canceled) {
				w.l.Error(E.FailWith("wake", err))
			}
			// Non-blocking send: the result is dropped when nobody is
			// currently receiving on wakeDone.
			select {
			case w.wakeDone <- err: // this is passed to roundtrip
			default:
			}
		}
	}
}