GoDoxy/internal/utils/serialization.go
2025-01-08 13:50:34 +08:00

427 lines
11 KiB
Go

package utils
import (
"encoding/json"
"errors"
"os"
"reflect"
"runtime/debug"
"strconv"
"strings"
"time"
"unicode"
"github.com/go-playground/validator/v10"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils"
"gopkg.in/yaml.v3"
)
// SerializedObject is the generic string-keyed form that Deserialize reads
// from. It is an alias, so any map[string]any can be used directly.
type SerializedObject = map[string]any
// Sentinel errors used (bare, with a subject, or wrapped) by the functions in
// this file.
var (
ErrInvalidType = E.New("invalid type")
// ErrNilValue is reported when a required source or destination is nil/invalid.
ErrNilValue = E.New("nil")
// ErrUnsettable is reported by Convert when the destination cannot be set.
ErrUnsettable = E.New("unsettable")
// ErrUnsupportedConversion is reported when no conversion path exists
// between the source and destination types.
ErrUnsupportedConversion = E.New("unsupported conversion")
// ErrUnknownField is reported by Deserialize for a source key that matches
// no destination struct field.
ErrUnknownField = E.New("unknown field")
)
// defaultValues maps a type to the factory registered for it through
// RegisterDefaultValueFactory; New consults it before falling back to
// reflect.New.
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
// RegisterDefaultValueFactory registers factory as the constructor that New
// uses whenever a fresh value of type T is needed.
//
// It panics if T is itself a pointer type, or if a factory for T has already
// been registered.
func RegisterDefaultValueFactory[T any](factory func() *T) {
	typ := reflect.TypeFor[T]()
	switch {
	case typ.Kind() == reflect.Ptr:
		panic("pointer of pointer")
	case defaultValues.Has(typ):
		panic("default value for " + typ.String() + " already registered")
	}
	// Wrap the typed factory so it can be stored under the common func() any
	// signature; the stored value is still a *T.
	defaultValues.Store(typ, func() any { return factory() })
}
// New returns a reflect.Value holding a pointer to a new value of type t.
//
// When a factory was registered for t via RegisterDefaultValueFactory, the
// pointer it produces is used; otherwise the result is reflect.New(t), i.e. a
// pointer to the zero value.
func New(t reflect.Type) reflect.Value {
	factory, registered := defaultValues.Load(t)
	if !registered {
		return reflect.New(t)
	}
	return reflect.ValueOf(factory())
}
// extractFields returns the exported fields of the struct type t, in
// declaration order. Pointer types are dereferenced first. Exported embedded
// (anonymous) fields are not returned themselves; their own exported fields
// are flattened into the result instead.
//
// It returns nil when t (after dereferencing) is not a struct.
func extractFields(t reflect.Type) []reflect.StructField {
	switch t.Kind() {
	case reflect.Ptr:
		return extractFields(t.Elem())
	case reflect.Struct:
		// handled below
	default:
		return nil
	}

	var out []reflect.StructField
	for idx, n := 0, t.NumField(); idx < n; idx++ {
		sf := t.Field(idx)
		switch {
		case !sf.IsExported():
			// Unexported fields (including unexported embedded types) are skipped.
		case sf.Anonymous:
			out = append(out, extractFields(sf.Type)...)
		default:
			out = append(out, sf)
		}
	}
	return out
}
// Deserialize assigns the entries of src to dst.
//
// dst must be a struct or a string-keyed map, or a (possibly nested) pointer
// to one. Nil pointers along the way are allocated with New when they are
// settable. In practice a struct has to be passed by pointer, because Convert
// rejects fields that cannot be set.
//
// Struct destinations:
//   - Each exported field (embedded structs are flattened) is keyed by the
//     name part of its `json` tag, or by its Go name when there is no tag.
//     Fields tagged `json:"-"` are skipped.
//   - Keys on both sides are normalized with strutils.ToLowerNoSnake, so
//     matching ignores case and underscores.
//   - Every entry of a comma separated `aliases` tag is registered as an
//     additional key for the field.
//   - Source keys that match no field are reported as ErrUnknownField.
//   - If any field carries a `validate` tag, the populated struct is validated
//     and each violation is reported as ErrValidationError.
//
// Map destinations: every source value is converted to the map's element type
// and stored under the same key. A nil map is allocated first.
//
// All per-key failures are collected and returned together. An error is also
// returned when src or dst is nil, or when dst is neither a struct nor a map.
func Deserialize(src SerializedObject, dst any) E.Error {
	if src == nil {
		return E.Errorf("deserialize: src is %w", ErrNilValue)
	}
	if dst == nil {
		return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
	}

	dstV := reflect.ValueOf(dst)
	dstT := dstV.Type()

	// Walk through pointers, allocating nil ones where possible.
	for dstT.Kind() == reflect.Ptr {
		if dstV.IsNil() {
			if !dstV.CanSet() {
				return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
			}
			dstV.Set(New(dstT.Elem()))
		}
		dstV = dstV.Elem()
		dstT = dstV.Type()
	}

	errs := E.NewBuilder("deserialize error")

	switch dstV.Kind() {
	case reflect.Struct:
		// needValidate latches to true as soon as ANY field has a `validate`
		// tag; it must not be overwritten by later fields without one.
		needValidate := false
		mapping := make(map[string]reflect.Value)

		for _, field := range extractFields(dstT) {
			var key string
			if jsonTag, ok := field.Tag.Lookup("json"); ok {
				if jsonTag == "-" {
					continue
				}
				key = strutils.CommaSeperatedList(jsonTag)[0]
			} else {
				key = field.Name
			}
			key = strutils.ToLowerNoSnake(key)

			fieldV := dstV.FieldByName(field.Name)
			mapping[key] = fieldV

			if _, ok := field.Tag.Lookup("validate"); ok {
				needValidate = true
			}

			if aliases, ok := field.Tag.Lookup("aliases"); ok {
				// NOTE(review): aliases are stored verbatim while lookups use
				// the normalized source key, so an alias presumably only
				// matches when it is already in normalized form — confirm.
				for _, alias := range strutils.CommaSeperatedList(aliases) {
					mapping[alias] = fieldV
				}
			}
		}

		for k, v := range src {
			field, ok := mapping[strutils.ToLowerNoSnake(k)]
			if !ok {
				errs.Add(ErrUnknownField.Subject(k).Withf(strutils.DoYouMean(NearestField(k, mapping))))
				continue
			}
			if err := Convert(reflect.ValueOf(v), field); err != nil {
				errs.Add(err.Subject(k))
			}
		}

		if needValidate {
			err := validate.Struct(dstV.Interface())
			var valErrs validator.ValidationErrors
			if errors.As(err, &valErrs) {
				for _, e := range valErrs {
					detail := e.ActualTag()
					if e.Param() != "" {
						detail += ":" + e.Param()
					}
					errs.Add(ErrValidationError.
						Subject(e.StructNamespace()).
						Withf("require %q", detail))
				}
			}
		}
		return errs.Error()

	case reflect.Map:
		if dstV.IsNil() {
			// A nil map passed by value cannot be replaced; report it instead
			// of letting reflect panic, mirroring the nil pointer case above.
			if !dstV.CanSet() {
				return E.Errorf("deserialize: dst is %w\n%s", ErrNilValue, debug.Stack())
			}
			dstV.Set(reflect.MakeMap(dstT))
		}
		elemT := dstT.Elem()
		for k := range src {
			// Convert into a fresh, settable element, then store it.
			tmp := New(elemT).Elem()
			if err := Convert(reflect.ValueOf(src[k]), tmp); err != nil {
				errs.Add(err.Subject(k))
				continue
			}
			dstV.SetMapIndex(reflect.ValueOf(k), tmp)
		}
		return errs.Error()

	default:
		return ErrUnsupportedConversion.Subject("deserialize to " + dstT.String())
	}
}
// isIntFloat reports whether t is one of the scalar kinds that can be
// round-tripped through strconv: bool, the signed and unsigned integers
// (including uintptr), and the floating point kinds.
func isIntFloat(t reflect.Kind) bool {
	switch t {
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64:
		return true
	default:
		return false
	}
}
// Convert stores src into dst, converting between representations as needed.
//
// Rules, applied in order after src has been unwrapped from an interface and
// dst has been dereferenced one pointer level (allocating it with New when nil):
//   - A nil/invalid src resets dst to its zero value.
//   - If src's type is assignable to dst's type, it is assigned directly.
//   - A string src is handed to ConvertString.
//   - A bool/integer/float src is formatted as a string and handed to
//     ConvertString.
//   - A map src must be a SerializedObject and is passed to Deserialize. An
//     empty map leaves dst untouched.
//   - A slice src requires a slice dst; every element is converted
//     recursively into a new slice. An empty slice leaves dst untouched.
//
// Returns:
//   - an error if dst is invalid or cannot be set, ErrUnsupportedConversion
//     if none of the rules apply, any error produced by the nested
//     conversion, or nil on success.
func Convert(src reflect.Value, dst reflect.Value) E.Error {
	if !dst.IsValid() {
		return E.Errorf("convert: dst is %w", ErrNilValue)
	}

	// Values read out of []any or map[string]any arrive with Kind Interface.
	// Unwrap them BEFORE the validity check: the Elem of a nil interface is
	// the invalid Value, and calling Type on it would panic.
	if src.IsValid() && src.Kind() == reflect.Interface {
		src = src.Elem()
	}
	if !src.IsValid() {
		if dst.CanSet() {
			dst.Set(reflect.Zero(dst.Type()))
			return nil
		}
		return E.Errorf("convert: src is %w", ErrNilValue)
	}

	srcT := src.Type()
	dstT := dst.Type()

	if !dst.CanSet() {
		return ErrUnsettable.Subject(dstT.String())
	}

	if dst.Kind() == reflect.Pointer {
		if dst.IsNil() {
			dst.Set(New(dstT.Elem()))
		}
		dst = dst.Elem()
		dstT = dst.Type()
	}

	srcKind := srcT.Kind()

	switch {
	case srcT.AssignableTo(dstT):
		dst.Set(src)
		return nil
	// Plain reflect conversion (srcT.ConvertibleTo(dstT)) is intentionally not
	// used here; it was disabled in the original code.
	case srcKind == reflect.String:
		if convertible, err := ConvertString(src.String(), dst); convertible {
			return err
		}
	case isIntFloat(srcKind):
		// Normalize scalars to their string form so that a single code path
		// (ConvertString) handles parsing into the destination type.
		var strV string
		switch {
		case src.CanInt():
			strV = strconv.FormatInt(src.Int(), 10)
		case srcKind == reflect.Bool:
			strV = strconv.FormatBool(src.Bool())
		case src.CanUint():
			strV = strconv.FormatUint(src.Uint(), 10)
		case src.CanFloat():
			strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
		}
		if convertible, err := ConvertString(strV, dst); convertible {
			return err
		}
	case srcKind == reflect.Map:
		if src.Len() == 0 {
			return nil
		}
		obj, ok := src.Interface().(SerializedObject)
		if !ok {
			return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
		}
		return Deserialize(obj, dst.Addr().Interface())
	case srcKind == reflect.Slice:
		if src.Len() == 0 {
			return nil
		}
		if dstT.Kind() != reflect.Slice {
			return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
		}
		newSlice := reflect.MakeSlice(dstT, 0, src.Len())
		i := 0
		for _, v := range src.Seq2() {
			tmp := New(dstT.Elem()).Elem()
			if err := Convert(v, tmp); err != nil {
				return err.Subjectf("[%d]", i)
			}
			newSlice = reflect.Append(newSlice, tmp)
			i++
		}
		dst.Set(newSlice)
		return nil
	}

	// Either no rule matched, or ConvertString reported the pair as not
	// convertible.
	return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}
// ConvertString parses src into dst according to dst's type.
//
// dst is dereferenced one pointer level first (a nil pointer is allocated with
// New). Supported destinations, checked in this order:
//   - string kinds: src is stored as is.
//   - time.Duration: parsed with time.ParseDuration; "" yields zero.
//   - bool, integer and float kinds: parsed with strconv.
//   - types whose pointer implements strutils.Parser: Parse(src) is called.
//   - slices: a single-line src is split as a comma separated list and each
//     item converted; a multi-line src is read as a YAML-like list where
//     leading "-" and whitespace are stripped and blank or "#" lines dropped.
//   - maps and structs: src is decoded as YAML and passed on to Convert.
//
// convertible is false only when dst's type is none of the above; convErr
// carries any parse or conversion failure.
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.Error) {
	dstT := dst.Type()
	if dst.Kind() == reflect.Ptr {
		if dst.IsNil() {
			dst.Set(New(dstT.Elem()))
		}
		dst = dst.Elem()
		dstT = dst.Type()
	}
	kind := dst.Kind()

	if kind == reflect.String {
		dst.SetString(src)
		return true, nil
	}

	if dstT == reflect.TypeFor[time.Duration]() {
		if src == "" {
			dst.Set(reflect.Zero(dstT))
			return true, nil
		}
		d, err := time.ParseDuration(src)
		if err != nil {
			return true, E.From(err)
		}
		dst.Set(reflect.ValueOf(d))
		return true, nil
	}

	if isIntFloat(kind) {
		var (
			parsed any
			err    error
		)
		switch {
		case kind == reflect.Bool:
			parsed, err = strconv.ParseBool(src)
		case dst.CanInt():
			parsed, err = strconv.ParseInt(src, 10, dstT.Bits())
		case dst.CanUint():
			parsed, err = strconv.ParseUint(src, 10, dstT.Bits())
		case dst.CanFloat():
			parsed, err = strconv.ParseFloat(src, dstT.Bits())
		}
		if err != nil {
			return true, E.From(err)
		}
		// strconv returns the widest type; narrow it to the destination type.
		dst.Set(reflect.ValueOf(parsed).Convert(dstT))
		return true, nil
	}

	// Let the destination parse itself when *T implements strutils.Parser.
	if parser, ok := dst.Addr().Interface().(strutils.Parser); ok {
		return true, E.From(parser.Parse(src))
	}

	// Remaining cases use a YAML-like reading of src.
	var intermediate any
	switch kind {
	case reflect.Slice:
		if !strings.ContainsRune(src, '\n') {
			// A one-liner is a comma separated list.
			items := strutils.CommaSeperatedList(src)
			dst.Set(reflect.MakeSlice(dst.Type(), len(items), len(items)))
			errs := E.NewBuilder("invalid slice values")
			for i, item := range items {
				if err := Convert(reflect.ValueOf(item), dst.Index(i)); err != nil {
					errs.Add(err.Subjectf("[%d]", i))
				}
			}
			if errs.HasError() {
				return true, errs.Error()
			}
			return true, nil
		}

		isBulletOrSpace := func(r rune) bool {
			return r == '-' || unicode.IsSpace(r)
		}
		lines := strutils.SplitLine(src)
		entries := make([]string, 0, len(lines))
		for _, ln := range lines {
			ln = strings.TrimLeftFunc(ln, isBulletOrSpace)
			if ln != "" && ln[0] != '#' {
				entries = append(entries, ln)
			}
		}
		intermediate = entries
	case reflect.Map, reflect.Struct:
		rawMap := make(SerializedObject)
		if err := yaml.Unmarshal([]byte(src), &rawMap); err != nil {
			return true, E.From(err)
		}
		intermediate = rawMap
	default:
		return false, nil
	}
	return true, Convert(reflect.ValueOf(intermediate), dst)
}
// DeserializeYAML decodes data as a YAML mapping and deserializes the result
// into target via Deserialize. A YAML decoding failure is returned wrapped
// with E.From.
func DeserializeYAML[T any](data []byte, target T) E.Error {
	raw := map[string]any{}
	err := yaml.Unmarshal(data, raw)
	if err == nil {
		return Deserialize(raw, target)
	}
	return E.From(err)
}
// DeserializeYAMLMap decodes data as a YAML mapping, converts every value to V
// through Deserialize, and returns the result wrapped by functional.NewMapFrom.
// On failure the zero functional.Map is returned together with the error.
func DeserializeYAMLMap[V any](data []byte) (_ functional.Map[string, V], err E.Error) {
	var empty functional.Map[string, V]

	raw := map[string]any{}
	if err = E.From(yaml.Unmarshal(data, raw)); err != nil {
		return empty, err
	}

	typed := make(map[string]V, len(raw))
	if err = Deserialize(raw, typed); err != nil {
		return empty, err
	}
	return functional.NewMapFrom(typed), nil
}
// LoadJSON reads the file at path and decodes its JSON content into dst.
// It returns the read error or the decoding error, whichever occurs first.
func LoadJSON[T any](path string, dst *T) error {
	raw, readErr := os.ReadFile(path)
	if readErr == nil {
		return json.Unmarshal(raw, dst)
	}
	return readErr
}
// SaveJSON encodes src as JSON and writes it to path with permission perm.
// Nothing is written when encoding fails; the encoding error is returned.
func SaveJSON[T any](path string, src *T, perm os.FileMode) error {
	encoded, encErr := json.Marshal(src)
	if encErr == nil {
		return os.WriteFile(path, encoded, perm)
	}
	return encErr
}