package task

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/rs/zerolog/log"
)

var (
	// ErrProgramExiting is the cause gracefulShutdown passes to root.Finish
	// when the program is shutting down.
	ErrProgramExiting = errors.New("program exiting")

	// root is the package-level ancestor task; it is (re)built by initRoot.
	root *Task

	// closedCh is closed once in init and is used by initRoot as the root
	// task's done channel.
	closedCh = make(chan struct{})
)

// init closes the shared closedCh channel and then builds the initial root
// task. closedCh is closed first so that it is already closed by the time
// initRoot stores it as root's done channel.
func init() {
	close(closedCh)
	initRoot()
}

func initRoot() {
	ctx, cancel := context.WithCancelCause(context.Background())
	root = &Task{
		name:   "root",
		ctx:    ctx,
		cancel: cancel,
		done:   closedCh,
	}
	root.parent = root
}

// testCleanup cancels the current root task's context with a nil cause, then
// replaces root with a newly initialized one. Judging by its name, it is
// presumably called from tests to reset package state between cases.
func testCleanup() {
	root.cancel(nil)
	initRoot()
}

// RootTask creates a subtask of the package-level root task with the given
// name. needFinish is forwarded unchanged to (*Task).Subtask; see that method
// for its exact meaning.
//
//go:inline
func RootTask(name string, needFinish bool) *Task {
	return root.Subtask(name, needFinish)
}

// RootContext returns the context of the package-level root task, as
// reported by root.Context().
func RootContext() context.Context {
	return root.Context()
}

func RootContextCanceled() <-chan struct{} {
	return root.Context().Done()
}

// OnProgramExit registers fn on the root task through root.OnCancel. The
// about string labels the callback. fn is presumably invoked when the root
// task is canceled, i.e. during program shutdown; see (*Task).OnCancel.
func OnProgramExit(about string, fn func()) {
	root.OnCancel(about, fn)
}

// WaitExit blocks until the process receives SIGINT, SIGTERM or SIGHUP. It
// then shuts down gracefully, allowing up to shutdownTimeout seconds for all
// tasks to finish.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func WaitExit(shutdownTimeout int) {
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

	// Block until the first shutdown signal arrives.
	<-stop

	log.Info().Msg("shutting down")
	timeout := time.Duration(shutdownTimeout) * time.Second
	err := gracefulShutdown(timeout)
	if err != nil {
		root.reportStucked()
	}
}

// gracefulShutdown finishes the root task with ErrProgramExiting as the cause
// and waits up to timeout for root.waitFinish to report completion.
//
// It returns nil if the root task finished in time, and
// context.DeadlineExceeded if the timeout elapsed first. It does not print or
// report still-running tasks itself; the caller (e.g. WaitExit) is
// responsible for that when an error is returned.
func gracefulShutdown(timeout time.Duration) error {
	// Finish runs in its own goroutine so the wait below can be bounded by
	// timeout. If the deadline passes, that goroutine is not waited on.
	go root.Finish(ErrProgramExiting)
	if !root.waitFinish(timeout) {
		return context.DeadlineExceeded
	}
	return nil
}

// invokeWithRecover runs cb.fn. If cb.fn panics, the recovered value is
// converted with fmtCause and logged together with the callback's about
// label, and then panicWithDebugStack is called. That call presumably
// re-panics with a stack trace; see its definition.
func invokeWithRecover(cb *Callback) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Err(fmtCause(r)).Str("callback", cb.about).Msg("panic")
		panicWithDebugStack()
	}()
	cb.fn()
}

//go:inline
// fmtCause normalizes an arbitrary value, such as a recovered panic value,
// into an error:
//   - a nil interface yields nil;
//   - a value that implements error is returned as-is;
//   - a string s becomes errors.New(s);
//   - anything else is formatted with %v.
func fmtCause(cause any) error {
	if cause == nil {
		return nil
	}
	if err, ok := cause.(error); ok {
		return err
	}
	if msg, ok := cause.(string); ok {
		return errors.New(msg)
	}
	return fmt.Errorf("%v", cause)
}