From eb7495b02a067828ea859cc2496490e25b0d39e5 Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 14 Apr 2025 06:32:16 +0800 Subject: [PATCH] fix: unmarshal --- internal/utils/serialization.go | 51 ++++++++++++++++------------ internal/utils/serialization_test.go | 12 +++---- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index c89c51a..995512e 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -308,6 +308,20 @@ func isIntFloat(t reflect.Kind) bool { return t >= reflect.Bool && t <= reflect.Float64 } +func itoa(v reflect.Value) string { + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(v.Int(), 10) + case reflect.Bool: + return strconv.FormatBool(v.Bool()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return strconv.FormatUint(v.Uint(), 10) + case reflect.Float32, reflect.Float64: + return strconv.FormatFloat(v.Float(), 'f', -1, 64) + } + panic("invalid call on itoa") +} + // Convert attempts to convert the src to dst. // // If src is a map, it is deserialized into dst. @@ -365,20 +379,12 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error { } case isIntFloat(srcKind): if dst.Kind() == reflect.String { - var strV string - switch { - case src.CanInt(): - strV = strconv.FormatInt(src.Int(), 10) - case srcKind == reflect.Bool: - strV = strconv.FormatBool(src.Bool()) - case src.CanUint(): - strV = strconv.FormatUint(src.Uint(), 10) - case src.CanFloat(): - strV = strconv.FormatFloat(src.Float(), 'f', -1, 64) - } - dst.Set(reflect.ValueOf(strV)) + dst.Set(reflect.ValueOf(itoa(src))) return nil } + if dst.Addr().Type().Implements(typeStrParser) { + return Convert(reflect.ValueOf(itoa(src)), dst) + } if !isIntFloat(dstT.Kind()) || !src.CanConvert(dstT) { return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT) } @@ -445,7 +451,8 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe dst = dst.Elem() dstT = dst.Type() } - if dst.Kind() == reflect.String { + dstKind := dst.Kind() + if dstKind == reflect.String { dst.SetString(src) return } @@ -482,7 +489,7 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe setSameOrEmbedddType(reflect.ValueOf(ipnet).Elem(), dst) return } - if dstKind := dst.Kind(); isIntFloat(dstKind) { + if isIntFloat(dstKind) { var i any var err error switch { @@ -512,7 +519,14 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe } // yaml like var tmp any - switch dst.Kind() { + switch dstKind { + case reflect.Map, reflect.Struct: + rawMap := make(SerializedObject) + err := yaml.Unmarshal([]byte(src), &rawMap) + if err != nil { + return true, gperr.Wrap(err) + } + tmp = rawMap case reflect.Slice: src = strings.TrimSpace(src) isMultiline := strings.ContainsRune(src, '\n') @@ -538,13 +552,6 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe return true, gperr.Wrap(err) } tmp = sl - case reflect.Map, reflect.Struct: - rawMap := make(SerializedObject) - err := yaml.Unmarshal([]byte(src), &rawMap) - if err != nil { - return true, gperr.Wrap(err) - } - tmp = rawMap default: return false, nil } diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 7e0ea5d..4c67c75 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -7,9 +7,7 @@ import ( "reflect" "strconv" "testing" - "time" - "github.com/yusing/go-proxy/internal/utils/strutils" . "github.com/yusing/go-proxy/internal/utils/testing" "gopkg.in/yaml.v3" ) @@ -95,12 +93,12 @@ func TestStringIntConvert(t *testing.T) { field := refl.Elem().Field(i) t.Run(fmt.Sprintf("field_%s", field.Type().Name()), func(t *testing.T) { ok, err := ConvertString("127", field) - ExpectTrue(t, ok) - ExpectNoError(t, err) + ExpectTrue(t, ok) + ExpectNoError(t, err) ExpectEqualValues(t, field.Interface(), 127) err = Convert(reflect.ValueOf(uint8(64)), field) - ExpectNoError(t, err) + ExpectNoError(t, err) ExpectEqualValues(t, field.Interface(), 64) }) } @@ -212,8 +210,8 @@ func BenchmarkStringToMapYAML(b *testing.B) { func TestStringToStruct(t *testing.T) { type T struct { - A string - B int + A string + B int } t.Run("yaml-like simple", func(t *testing.T) { var dst T