package utils

import (
	"encoding/json"
	"errors"
	"os"
	"reflect"
	"runtime/debug"
	"strconv"
	"strings"
	"time"

	"github.com/go-playground/validator/v10"
	"github.com/yusing/go-proxy/internal/gperr"
	"github.com/yusing/go-proxy/internal/utils/functional"
	"github.com/yusing/go-proxy/internal/utils/strutils"
	"gopkg.in/yaml.v3"
)

type (
	// SerializedObject is the generic decoded form of a mapping (string keys,
	// arbitrary values), e.g. the result of yaml.Unmarshal into a map, that
	// MapUnmarshalValidate maps onto a target value.
	SerializedObject = map[string]any

	// MapUnmarshaller is implemented by types that decode themselves from a
	// SerializedObject instead of relying on the reflection-based field mapping
	// performed by MapUnmarshalValidate.
	MapUnmarshaller interface {
		UnmarshalMap(m map[string]any) gperr.Error
	}
)

// Sentinel errors reported by the deserialization and conversion helpers.
var (
	ErrInvalidType           = gperr.New("invalid type")
	ErrNilValue              = gperr.New("nil")
	ErrUnsettable            = gperr.New("unsettable")
	ErrUnsupportedConversion = gperr.New("unsupported conversion")
	ErrUnknownField          = gperr.New("unknown field")
)

// Struct tag keys consulted when mapping a SerializedObject onto struct fields.
var (
	// tagDeserialize: `deserialize:"-"` excludes a field from deserialization.
	tagDeserialize = "deserialize"

	// tagJSON: the json tag, shared with json.Marshal; its name part selects the key.
	tagJSON = "json"

	// tagValidate: field constraints checked with go-playground/validator.
	tagValidate = "validate"

	// tagAliases: comma-separated alternative keys accepted for a field.
	tagAliases = "aliases"
)

var (
	// mapUnmarshalerType is the reflect.Type of the MapUnmarshaller interface.
	mapUnmarshalerType = reflect.TypeFor[MapUnmarshaller]()

	// defaultValues holds the per-type factories registered through
	// RegisterDefaultValueFactory; New consults it before reflect.New.
	defaultValues = functional.NewMapOf[reflect.Type, func() any]()
)

// RegisterDefaultValueFactory registers factory as the source of fresh *T values
// returned by New for type T.
//
// It panics if T is itself a pointer type, or if a factory for T has already been
// registered.
func RegisterDefaultValueFactory[T any](factory func() *T) {
	typ := reflect.TypeFor[T]()
	switch {
	case typ.Kind() == reflect.Ptr:
		panic("pointer of pointer")
	case defaultValues.Has(typ):
		panic("default value for " + typ.String() + " already registered")
	}
	// Box the typed factory so it fits the map's func() any value type.
	defaultValues.Store(typ, func() any { return factory() })
}

// New returns a Value holding a pointer to a fresh value of type t.
//
// If a factory was registered for t via RegisterDefaultValueFactory, its result is
// returned; otherwise the pointer comes from reflect.New(t) (a zero value).
func New(t reflect.Type) reflect.Value {
	factory, registered := defaultValues.Load(t)
	if !registered {
		return reflect.New(t)
	}
	return reflect.ValueOf(factory())
}

func extractFields(t reflect.Type) (all, anonymous []reflect.StructField) {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil, nil
	}
	n := t.NumField()
	fields := make([]reflect.StructField, 0, n)
	for i := range n {
		field := t.Field(i)
		if !field.IsExported() {
			continue
		}
		if field.Tag.Get(tagDeserialize) == "-" {
			continue
		}
		if field.Anonymous {
			f1, f2 := extractFields(field.Type)
			fields = append(fields, f1...)
			anonymous = append(anonymous, field)
			anonymous = append(anonymous, f2...)
		} else {
			fields = append(fields, field)
		}
	}
	return fields, anonymous
}

// ValidateWithFieldTags validates s against its `validate` struct tags using the
// package-level validator.
//
// Each failed constraint becomes an ErrValidationError whose subject is the field
// namespace and whose detail is the failed tag, plus ":param" when the tag has a
// parameter. An error from validate.Struct that is not validator.ValidationErrors
// is ignored. The result is the builder's aggregated error, presumably nil when
// nothing failed.
func ValidateWithFieldTags(s any) gperr.Error {
	builder := gperr.NewBuilder("validate error")
	var fieldErrs validator.ValidationErrors
	if !errors.As(validate.Struct(s), &fieldErrs) {
		return builder.Error()
	}
	for _, fe := range fieldErrs {
		requirement := fe.ActualTag()
		if param := fe.Param(); param != "" {
			requirement = requirement + ":" + param
		}
		builder.Add(ErrValidationError.Subject(fe.Namespace()).Withf("require %q", requirement))
	}
	return builder.Error()
}

// ValidateWithCustomValidator runs the CustomValidator implementation reachable
// from v, if any.
//
// Starting at v, it follows pointers and interfaces. At each step the value itself
// is checked for CustomValidator. A struct is checked through its address, so
// pointer-receiver Validate methods are found, and the walk stops there. A
// non-addressable struct, or any other kind, is not validated and yields nil. A nil
// pointer or interface met on the way yields an ErrNilValue-wrapped error.
func ValidateWithCustomValidator(v reflect.Value) gperr.Error {
	viaStructAddr := false
	for {
		kind := v.Kind()
		if kind == reflect.Struct {
			if !v.CanAddr() {
				return nil
			}
			v, viaStructAddr = v.Addr(), true
			continue
		}
		if kind != reflect.Pointer && kind != reflect.Interface {
			return nil
		}
		if v.IsNil() {
			return gperr.Errorf("validate: v is %w", ErrNilValue)
		}
		if cv, ok := v.Interface().(CustomValidator); ok {
			return cv.Validate()
		}
		if viaStructAddr {
			// The struct's own address did not implement it; do not descend again.
			return nil
		}
		v = v.Elem()
	}
}

// dive follows pointers and interfaces from dst down to the first value that is
// neither, and returns that value together with its type.
//
// Along the way it allocates what it can:
//   - a nil pointer is replaced by New(elem) when it is settable;
//   - a nil map or slice that is reached is replaced by an empty one when settable.
//
// It returns an ErrNilValue-wrapped error in these cases:
//   - dst is invalid;
//   - a nil pointer, map, or slice cannot be set;
//   - an interface is nil, so its concrete type is unknown and nothing can be
//     allocated.
//
// A nil chan, func, or unsafe pointer yields ErrInvalidType. Values of any other
// kind (struct, int, string, ...) are returned as-is; callers decide whether they
// are acceptable.
func dive(dst reflect.Value) (v reflect.Value, t reflect.Type, err gperr.Error) {
	if !dst.IsValid() {
		err = gperr.Errorf("dive: dst is %w", ErrNilValue)
		return
	}
	dstT := dst.Type()
	for {
		switch dst.Kind() {
		case reflect.Pointer, reflect.Interface:
			if dst.IsNil() {
				if !dst.CanSet() {
					err = gperr.Errorf("dive: dst is %w and is not settable", ErrNilValue)
					return
				}
				if dst.Kind() == reflect.Interface {
					// reflect.Type.Elem panics for interface types: there is no
					// concrete type that could be allocated here.
					err = gperr.Errorf("dive: dst is %w interface %s", ErrNilValue, dstT.String())
					return
				}
				dst.Set(New(dstT.Elem()))
			}
			dst = dst.Elem()
			dstT = dst.Type()
		case reflect.Map, reflect.Slice:
			if dst.IsNil() {
				if !dst.CanSet() {
					err = gperr.Errorf("dive: dst is %w and is not settable", ErrNilValue)
					return
				}
				if dst.Kind() == reflect.Map {
					dst.Set(reflect.MakeMap(dstT))
				} else {
					dst.Set(reflect.MakeSlice(dstT, 0, 0))
				}
			}
			return dst, dstT, nil
		case reflect.Chan, reflect.Func, reflect.UnsafePointer:
			if dst.IsNil() {
				err = gperr.Errorf("deserialize: %w for dst %s", ErrInvalidType, dstT.String())
				return
			}
			return dst, dstT, nil
		default:
			// Structs and non-nilable kinds: reflect.Value.IsNil would panic on
			// these, so they are returned without a nil check.
			return dst, dstT, nil
		}
	}
}

// MapUnmarshalValidate assigns the values of src to dst and validates the result.
// Key matching ignores case and underscores: both src keys and target keys are
// normalized with strutils.ToLowerNoSnake.
//
// dst is normally a non-nil pointer. It is dereferenced (see dive) down to a
// struct or a map:
//   - If dst's type implements MapUnmarshaller, its UnmarshalMap method is called
//     with src and that result is returned.
//   - For a struct, each src entry is converted into the field selected by the
//     name part of its `json` tag (the Go field name when absent or empty) or by
//     one of its `aliases`. Unknown keys are reported with a "did you mean" hint.
//     Afterwards `validate` tags (if any field has one) and a CustomValidator
//     implementation are checked.
//   - For a map, every src value is converted to the map's element type, checked
//     with ValidateWithCustomValidator, and stored under its key.
//
// Per-key errors are collected and returned together. Any other target kind
// yields ErrUnsupportedConversion. A nil dst yields an ErrNilValue-wrapped error.
func MapUnmarshalValidate(src SerializedObject, dst any) (err gperr.Error) {
	dstV := reflect.ValueOf(dst)
	if !dstV.IsValid() {
		// reflect.ValueOf(nil) carries no type; calling Type() on it would panic.
		return gperr.Errorf("deserialize: dst is %w", ErrNilValue)
	}
	dstT := dstV.Type()

	if src == nil {
		// NOTE(review): a Value obtained from reflect.ValueOf is never settable,
		// so this branch always returns the error below.
		if dstV.CanSet() {
			dstV.Set(reflect.Zero(dstT))
			return nil
		}
		return gperr.Errorf("deserialize: src is %w and dst is not settable\n%s", ErrNilValue, debug.Stack())
	}

	if dstT.Implements(mapUnmarshalerType) {
		dstV, _, err = dive(dstV)
		if err != nil {
			return err
		}
		// Call through the addressable target so pointer-receiver methods mutate
		// it. If the dived value is not addressable (e.g. a map type passed by
		// value), call the method on dst itself instead of panicking in Addr.
		if dstV.CanAddr() {
			if u, ok := dstV.Addr().Interface().(MapUnmarshaller); ok {
				return u.UnmarshalMap(src)
			}
		}
		return dst.(MapUnmarshaller).UnmarshalMap(src)
	}

	dstV, dstT, err = dive(dstV)
	if err != nil {
		return err
	}

	errs := gperr.NewBuilder("deserialize error")

	switch dstV.Kind() {
	case reflect.Struct:
		hasValidateTag := false
		mapping := make(map[string]reflect.Value)
		fields, anonymous := extractFields(dstT)
		// Allocate nil embedded pointers first so promoted fields are reachable.
		for _, anon := range anonymous {
			if field := dstV.FieldByName(anon.Name); field.Kind() == reflect.Ptr && field.IsNil() {
				field.Set(New(anon.Type.Elem()))
			}
		}
		for _, field := range fields {
			key := field.Name
			if jsonTag, ok := field.Tag.Lookup(tagJSON); ok {
				if jsonTag == "-" {
					continue
				}
				// Same naming rule as encoding/json: the part before the first
				// comma names the key, and an empty name (e.g. `json:",omitempty"`)
				// falls back to the Go field name.
				name, _, _ := strings.Cut(jsonTag, ",")
				if name = strings.TrimSpace(name); name != "" {
					key = name
				}
			}
			fieldV := dstV.FieldByName(field.Name)
			mapping[strutils.ToLowerNoSnake(key)] = fieldV

			if !hasValidateTag {
				_, hasValidateTag = field.Tag.Lookup(tagValidate)
			}

			if aliases, ok := field.Tag.Lookup(tagAliases); ok {
				for _, alias := range strutils.CommaSeperatedList(aliases) {
					// Lookups normalize src keys, so aliases must be stored
					// normalized too or mixed-case/snake_case aliases never match.
					mapping[strutils.ToLowerNoSnake(alias)] = fieldV
				}
			}
		}
		for k, v := range src {
			field, ok := mapping[strutils.ToLowerNoSnake(k)]
			if !ok {
				errs.Add(ErrUnknownField.Subject(k).Withf("%s", strutils.DoYouMean(NearestField(k, mapping))))
				continue
			}
			if err := Convert(reflect.ValueOf(v), field); err != nil {
				errs.Add(err.Subject(k))
			}
		}
		if hasValidateTag {
			errs.Add(ValidateWithFieldTags(dstV.Interface()))
		}
		if err := ValidateWithCustomValidator(dstV); err != nil {
			errs.Add(err)
		}
		return errs.Error()
	case reflect.Map:
		elemT := dstT.Elem()
		for k, v := range src {
			tmp := New(elemT).Elem()
			if err := Convert(reflect.ValueOf(v), tmp); err != nil {
				errs.Add(err.Subject(k))
				continue
			}
			if err := ValidateWithCustomValidator(tmp.Addr()); err != nil {
				errs.Add(err.Subject(k))
				continue
			}
			dstV.SetMapIndex(reflect.ValueOf(k), tmp)
		}
		if err := ValidateWithCustomValidator(dstV); err != nil {
			errs.Add(err)
		}
		return errs.Error()
	default:
		return ErrUnsupportedConversion.Subject("mapping to " + dstT.String() + " ")
	}
}

// isIntFloat reports whether t is a bool, integer (signed, unsigned or uintptr),
// or float kind, i.e. a kind in the contiguous range reflect.Bool..reflect.Float64.
func isIntFloat(t reflect.Kind) bool {
	switch {
	case t < reflect.Bool, t > reflect.Float64:
		return false
	default:
		return true
	}
}

// Convert attempts to convert src and store the result in dst.
//
// Rules, applied in order:
//   - An interface-typed src (e.g. an element of []any) is unwrapped first. A nil
//     src, including a nil interface element, sets a settable dst to its zero value.
//   - A nil pointer dst is allocated with New; a pointer dst is dereferenced once.
//   - A src assignable to dst is stored directly.
//   - String sources, and bool/integer/float sources formatted to strings, are
//     parsed with ConvertString.
//   - A non-empty map[string]any src is decoded with MapUnmarshalValidate.
//   - A non-empty slice src is converted element-wise into a new slice of dst's
//     type, which replaces dst only if every element converted.
//
// Empty map and slice sources leave dst unchanged.
//
// Returns nil on success, an ErrNilValue/ErrUnsettable/ErrUnsupportedConversion
// based error, or the error produced by the nested conversion or deserialization.
func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
	if !dst.IsValid() {
		return gperr.Errorf("convert: dst is %w", ErrNilValue)
	}

	// Unwrap before the validity check: a nil interface unwraps to an invalid
	// Value and must be treated like a nil source rather than reaching src.Type().
	if src.Kind() == reflect.Interface {
		src = src.Elem()
	}

	if !src.IsValid() {
		if dst.CanSet() {
			dst.Set(reflect.Zero(dst.Type()))
			return nil
		}
		return gperr.Errorf("convert: src is %w", ErrNilValue)
	}

	srcT := src.Type()
	dstT := dst.Type()

	if !dst.CanSet() {
		return ErrUnsettable.Subject(dstT.String())
	}

	if dst.Kind() == reflect.Pointer {
		if dst.IsNil() {
			dst.Set(New(dstT.Elem()))
		}
		dst = dst.Elem()
		dstT = dst.Type()
	}

	srcKind := srcT.Kind()

	switch {
	case srcT.AssignableTo(dstT):
		dst.Set(src)
		return nil
	case srcKind == reflect.String:
		if convertible, err := ConvertString(src.String(), dst); convertible {
			return err
		}
	case isIntFloat(srcKind):
		var strV string
		switch {
		case src.CanInt():
			strV = strconv.FormatInt(src.Int(), 10)
		case srcKind == reflect.Bool:
			strV = strconv.FormatBool(src.Bool())
		case src.CanUint():
			strV = strconv.FormatUint(src.Uint(), 10)
		case src.CanFloat():
			strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
		}
		if convertible, err := ConvertString(strV, dst); convertible {
			return err
		}
	case srcKind == reflect.Map:
		if src.Len() == 0 {
			return nil
		}
		obj, ok := src.Interface().(SerializedObject)
		if !ok {
			return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
		}
		return MapUnmarshalValidate(obj, dst.Addr().Interface())
	case srcKind == reflect.Slice:
		n := src.Len()
		if n == 0 {
			return nil
		}
		if dstT.Kind() != reflect.Slice {
			return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
		}
		sliceErrs := gperr.NewBuilder("slice conversion errors")
		newSlice := reflect.MakeSlice(dstT, n, n)
		for i := range n {
			// Convert into a fresh element so registered default values apply.
			tmp := New(dstT.Elem()).Elem()
			if err := Convert(src.Index(i), tmp); err != nil {
				sliceErrs.Add(err.Subjectf("[%d]", i))
				continue
			}
			newSlice.Index(i).Set(tmp)
		}
		if err := sliceErrs.Error(); err != nil {
			return err
		}
		dst.Set(newSlice)
		return nil
	}
	return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}

// ConvertString parses src into dst.
//
// A pointer dst is allocated with New when nil, then dereferenced once.
// Handling by target:
//   - string kinds: src is stored verbatim; strutils.Parser is not consulted;
//   - time.Duration: time.ParseDuration, with "" meaning zero;
//   - bool/integer/float kinds: strconv parsing sized to the target's bit width;
//   - types whose address implements strutils.Parser: Parse(src);
//   - slices: blank input yields an empty slice. A single line not starting with
//     "-" is split as a comma-separated list, each element converted with
//     Convert. Anything else is parsed as a YAML sequence.
//   - maps and structs: src is parsed as a YAML mapping and converted with Convert.
//
// convertible is false only when dst's kind is none of the above. In that case
// convErr is nil and dst is untouched, apart from a possible nil-pointer allocation.
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) {
	convertible = true
	dstT := dst.Type()
	if dst.Kind() == reflect.Ptr {
		if dst.IsNil() {
			dst.Set(New(dstT.Elem()))
		}
		dst = dst.Elem()
		dstT = dst.Type()
	}
	if dst.Kind() == reflect.String {
		dst.SetString(src)
		return
	}
	if dstT == reflect.TypeFor[time.Duration]() {
		if src == "" {
			dst.Set(reflect.Zero(dstT))
			return
		}
		d, err := time.ParseDuration(src)
		if err != nil {
			return true, gperr.Wrap(err)
		}
		dst.Set(reflect.ValueOf(d))
		return
	}
	if dstKind := dst.Kind(); isIntFloat(dstKind) {
		var i any
		var err error
		switch {
		case dstKind == reflect.Bool:
			i, err = strconv.ParseBool(src)
		case dst.CanInt():
			i, err = strconv.ParseInt(src, 10, dstT.Bits())
		case dst.CanUint():
			i, err = strconv.ParseUint(src, 10, dstT.Bits())
		case dst.CanFloat():
			i, err = strconv.ParseFloat(src, dstT.Bits())
		}
		if err != nil {
			return true, gperr.Wrap(err)
		}
		dst.Set(reflect.ValueOf(i).Convert(dstT))
		return
	}
	// check if target implements string parser
	if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
		return true, gperr.Wrap(parser.Parse(src))
	}
	// yaml like
	var tmp any
	switch dst.Kind() {
	case reflect.Slice:
		src = strings.TrimSpace(src)
		if src == "" {
			// Blank input means "no elements"; indexing src[0] below would panic.
			dst.Set(reflect.MakeSlice(dstT, 0, 0))
			return
		}
		isMultiline := strings.ContainsRune(src, '\n')
		// treats one liner without leading dash as comma separated list
		if !isMultiline && src[0] != '-' {
			values := strutils.CommaSeperatedList(src)
			dst.Set(reflect.MakeSlice(dstT, len(values), len(values)))
			errs := gperr.NewBuilder("invalid slice values")
			for i, v := range values {
				if err := Convert(reflect.ValueOf(v), dst.Index(i)); err != nil {
					errs.Add(err.Subjectf("[%d]", i))
				}
			}
			if errs.HasError() {
				return true, errs.Error()
			}
			return
		}
		sl := make([]any, 0)
		if err := yaml.Unmarshal([]byte(src), &sl); err != nil {
			return true, gperr.Wrap(err)
		}
		tmp = sl
	case reflect.Map, reflect.Struct:
		rawMap := make(SerializedObject)
		if err := yaml.Unmarshal([]byte(src), &rawMap); err != nil {
			return true, gperr.Wrap(err)
		}
		tmp = rawMap
	default:
		return false, nil
	}
	return true, Convert(reflect.ValueOf(tmp), dst)
}

// UnmarshalValidateYAML decodes YAML data into a generic SerializedObject and
// then deserializes and validates it into target with MapUnmarshalValidate.
// data is expected to hold a YAML mapping. A YAML decoding error is returned
// wrapped with gperr.Wrap.
func UnmarshalValidateYAML[T any](data []byte, target *T) gperr.Error {
	// raw is passed to yaml.Unmarshal as a (non-nil) map value, which yaml.v3
	// accepts as an output target.
	raw := SerializedObject{}
	if yamlErr := yaml.Unmarshal(data, raw); yamlErr != nil {
		return gperr.Wrap(yamlErr)
	}
	return MapUnmarshalValidate(raw, target)
}

// UnmarshalValidateYAMLMap decodes a YAML mapping into a map whose values are
// converted to V and validated with MapUnmarshalValidate, and returns it as a
// functional.Map. On any error the zero map is returned along with the error.
func UnmarshalValidateYAMLMap[V any](data []byte) (_ functional.Map[string, V], err gperr.Error) {
	raw := SerializedObject{}
	if err = gperr.Wrap(yaml.Unmarshal(data, raw)); err != nil {
		return
	}
	typed := make(map[string]V, len(raw))
	// typed is non-nil, so MapUnmarshalValidate fills it in place.
	if err = MapUnmarshalValidate(raw, typed); err != nil {
		return
	}
	return functional.NewMapFrom(typed), nil
}

// loadSerialized reads the file at path and decodes its contents into dst using
// deserialize. The os.ReadFile error is returned unwrapped, so callers can test it
// for "not exist".
func loadSerialized[T any](path string, dst *T, deserialize func(data []byte, dst any) error) error {
	var (
		content []byte
		err     error
	)
	if content, err = os.ReadFile(path); err != nil {
		return err
	}
	return deserialize(content, dst)
}

// SaveJSON marshals src with encoding/json and writes the result to path with
// permission perm, using os.WriteFile. Nothing is written if marshaling fails.
func SaveJSON[T any](path string, src *T, perm os.FileMode) error {
	encoded, marshalErr := json.Marshal(src)
	if marshalErr != nil {
		return marshalErr
	}
	return os.WriteFile(path, encoded, perm)
}

// LoadJSONIfExist decodes the JSON file at path into dst.
//
// A missing file is not an error: dst is left untouched and nil is returned.
// Any other read error, or a JSON decoding error, is returned unchanged.
func LoadJSONIfExist[T any](path string, dst *T) error {
	err := loadSerialized(path, dst, json.Unmarshal)
	// errors.Is also matches wrapped not-exist errors, unlike the legacy
	// os.IsNotExist which only understands errors returned by package os.
	if errors.Is(err, os.ErrNotExist) {
		return nil
	}
	return err
}