From 8e2788623530f5df71c139db533e0d1570cebd9f Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 14 May 2025 12:20:52 +0800 Subject: [PATCH] fix: incorrect unmarshal behavior for pointer primitives --- internal/utils/serialization.go | 24 +++++++++--- internal/utils/serialization_test.go | 55 ++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index b2930d5..e540d23 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -314,12 +314,26 @@ func Convert(src reflect.Value, dst reflect.Value, checkValidateTag bool) gperr. return gperr.Errorf("convert: dst is %w", ErrNilValue) } - if !src.IsValid() || src.IsZero() { - if dst.CanSet() { - dst.Set(reflect.Zero(dst.Type())) - return nil + if (src.Kind() == reflect.Pointer && src.IsNil()) || !src.IsValid() { + if !dst.CanSet() { + return gperr.Errorf("convert: src is %w", ErrNilValue) } - return gperr.Errorf("convert: src is %w", ErrNilValue) + // manually set nil + dst.Set(reflect.Zero(dst.Type())) + return nil + } + + if src.IsZero() { + if !dst.CanSet() { + return gperr.Errorf("convert: src is %w", ErrNilValue) + } + switch dst.Kind() { + case reflect.Pointer, reflect.Interface: + dst.Set(reflect.New(dst.Type().Elem())) + default: + dst.Set(reflect.Zero(dst.Type())) + } + return nil } srcT := src.Type() diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 2350deb..6a08779 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -73,6 +73,61 @@ func TestDeserializeAnonymousField(t *testing.T) { ExpectEqual(t, s2.C, 3) } +func TestPointerPrimitives(t *testing.T) { + type testType struct { + B *bool `json:"b"` + I8 *int8 `json:"i8"` + I16 *int16 `json:"i16"` + I32 *int32 `json:"i32"` + I64 *int64 `json:"i64"` + U8 *uint8 `json:"u8"` + U16 *uint16 `json:"u16"` + U32 *uint32 `json:"u32"` + U64 *uint64 `json:"u64"` + } + var test testType + + err := MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test) + ExpectNoError(t, err) + ExpectEqual(t, *test.B, true) + ExpectEqual(t, *test.I8, int8(127)) + ExpectEqual(t, *test.I16, int16(127)) + ExpectEqual(t, *test.I32, int32(127)) + ExpectEqual(t, *test.I64, int64(127)) + ExpectEqual(t, *test.U8, uint8(127)) + ExpectEqual(t, *test.U16, uint16(127)) + ExpectEqual(t, *test.U32, uint32(127)) + ExpectEqual(t, *test.U64, uint64(127)) + + // zero values + err = MapUnmarshalValidate(map[string]any{"b": false, "i8": int8(0), "i16": int16(0), "i32": int32(0), "i64": int64(0), "u8": uint8(0), "u16": uint16(0), "u32": uint32(0), "u64": uint64(0)}, &test) + ExpectNoError(t, err) + ExpectEqual(t, *test.B, false) + ExpectEqual(t, *test.I8, int8(0)) + ExpectEqual(t, *test.I16, int16(0)) + ExpectEqual(t, *test.I32, int32(0)) + ExpectEqual(t, *test.I64, int64(0)) + ExpectEqual(t, *test.U8, uint8(0)) + ExpectEqual(t, *test.U16, uint16(0)) + ExpectEqual(t, *test.U32, uint32(0)) + ExpectEqual(t, *test.U64, uint64(0)) + + // nil values + err = MapUnmarshalValidate(map[string]any{"b": true, "i8": int8(127), "i16": int16(127), "i32": int32(127), "i64": int64(127), "u8": uint8(127), "u16": uint16(127), "u32": uint32(127), "u64": uint64(127)}, &test) + ExpectNoError(t, err) + err = MapUnmarshalValidate(map[string]any{"b": nil, "i8": nil, "i16": nil, "i32": nil, "i64": nil, "u8": nil, "u16": nil, "u32": nil, "u64": nil}, &test) + ExpectNoError(t, err) + ExpectEqual(t, test.B, nil) + ExpectEqual(t, test.I8, nil) + ExpectEqual(t, test.I16, nil) + ExpectEqual(t, test.I32, nil) + ExpectEqual(t, test.I64, nil) + ExpectEqual(t, test.U8, nil) + ExpectEqual(t, test.U16, nil) + ExpectEqual(t, test.U32, nil) + ExpectEqual(t, test.U64, nil) +} + func TestStringIntConvert(t *testing.T) { s := "127"