From 3f2dfe14b5401a1315bf7ab71c11721583d13a1a Mon Sep 17 00:00:00 2001 From: yusing Date: Sun, 13 Apr 2025 12:24:11 +0800 Subject: [PATCH] fix: unmarshal and some tests --- internal/gperr/subject.go | 3 +- internal/net/gphttp/accesslog/filter.go | 10 +- internal/net/gphttp/accesslog/filter_test.go | 8 +- internal/route/types/http_config_test.go | 1 + internal/utils/serialization.go | 79 ++++++++------- internal/utils/serialization_test.go | 100 ++++++++----------- 6 files changed, 100 insertions(+), 101 deletions(-) diff --git a/internal/gperr/subject.go b/internal/gperr/subject.go index 6bbee63..79ae475 100644 --- a/internal/gperr/subject.go +++ b/internal/gperr/subject.go @@ -2,6 +2,7 @@ package gperr import ( "encoding/json" + "errors" "slices" "strings" @@ -64,7 +65,7 @@ func (err *withSubject) Prepend(subject string) *withSubject { } func (err *withSubject) Is(other error) bool { - return err.Err == other + return errors.Is(other, err.Err) } func (err *withSubject) Unwrap() error { diff --git a/internal/net/gphttp/accesslog/filter.go b/internal/net/gphttp/accesslog/filter.go index a491113..dad1401 100644 --- a/internal/net/gphttp/accesslog/filter.go +++ b/internal/net/gphttp/accesslog/filter.go @@ -23,7 +23,7 @@ type ( Key, Value string } Host string - CIDR struct{ net.IPNet } + CIDR net.IPNet ) var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter") @@ -85,7 +85,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool { return req.Host == string(h) } -func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool { +func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool { ip, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { ip = req.RemoteAddr @@ -94,5 +94,9 @@ func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool { if netIP == nil { return false } - return cidr.Contains(netIP) + return (*net.IPNet)(cidr).Contains(netIP) +} + +func (cidr *CIDR) String() string { + return (*net.IPNet)(cidr).String() } diff --git a/internal/net/gphttp/accesslog/filter_test.go b/internal/net/gphttp/accesslog/filter_test.go index 306f498..5d8e8c8 100644 --- a/internal/net/gphttp/accesslog/filter_test.go +++ b/internal/net/gphttp/accesslog/filter_test.go @@ -157,11 +157,9 @@ func TestHeaderFilter(t *testing.T) { func TestCIDRFilter(t *testing.T) { cidr := []*CIDR{{ - net.IPNet{ - IP: net.ParseIP("192.168.10.0"), - Mask: net.CIDRMask(24, 32), - }}, - } + IP: net.ParseIP("192.168.10.0"), + Mask: net.CIDRMask(24, 32), + }} ExpectEqual(t, cidr[0].String(), "192.168.10.0/24") inCIDR := &http.Request{ RemoteAddr: "192.168.10.1", diff --git a/internal/route/types/http_config_test.go b/internal/route/types/http_config_test.go index 2846212..3040818 100644 --- a/internal/route/types/http_config_test.go +++ b/internal/route/types/http_config_test.go @@ -39,6 +39,7 @@ func TestHTTPConfigDeserialize(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := Route{} + tt.input["host"] = "internal" err := utils.MapUnmarshalValidate(tt.input, &cfg) if err != nil { ExpectNoError(t, err) diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 59d7293..c89c51a 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -47,11 +47,15 @@ var ( var ( typeDuration = reflect.TypeFor[time.Duration]() + typeTime = reflect.TypeFor[time.Time]() typeURL = reflect.TypeFor[url.URL]() - typeCIDR = reflect.TypeFor[*net.IPNet]() + typeCIDR = reflect.TypeFor[net.IPNet]() typeMapMarshaller = reflect.TypeFor[MapMarshaller]() typeMapUnmarshaler = reflect.TypeFor[MapUnmarshaller]() + typeJSONMarshaller = reflect.TypeFor[json.Marshaler]() + + typeAny = reflect.TypeOf((*any)(nil)).Elem() ) var defaultValues = functional.NewMapOf[reflect.Type, func() any]() @@ -360,20 +364,26 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error { return err } case isIntFloat(srcKind): - var strV string - switch { - case src.CanInt(): - strV = strconv.FormatInt(src.Int(), 10) - case srcKind == reflect.Bool: - strV = strconv.FormatBool(src.Bool()) - case src.CanUint(): - strV = strconv.FormatUint(src.Uint(), 10) - case src.CanFloat(): - strV = strconv.FormatFloat(src.Float(), 'f', -1, 64) + if dst.Kind() == reflect.String { + var strV string + switch { + case src.CanInt(): + strV = strconv.FormatInt(src.Int(), 10) + case srcKind == reflect.Bool: + strV = strconv.FormatBool(src.Bool()) + case src.CanUint(): + strV = strconv.FormatUint(src.Uint(), 10) + case src.CanFloat(): + strV = strconv.FormatFloat(src.Float(), 'f', -1, 64) + } + dst.Set(reflect.ValueOf(strV)) + return nil } - if convertible, err := ConvertString(strV, dst); convertible { - return err + if !isIntFloat(dstT.Kind()) || !src.CanConvert(dstT) { + return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT) } + dst.Set(src.Convert(dstT)) + return nil case srcKind == reflect.Map: if src.Len() == 0 { return nil @@ -412,8 +422,17 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error { return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT) } -func nilPointer[T any]() reflect.Value { - return reflect.ValueOf((*T)(nil)) +func isSameOrEmbededType(src, dst reflect.Type) bool { + return src == dst || src.ConvertibleTo(dst) +} + +func setSameOrEmbedddType(src, dst reflect.Value) { + dstT := dst.Type() + if src.Type().AssignableTo(dstT) { + dst.Set(src) + } else { + dst.Set(src.Convert(dstT)) + } } func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) { @@ -430,12 +449,12 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe dst.SetString(src) return } - switch dstT { - case typeDuration: - if src == "" { - dst.Set(reflect.Zero(dstT)) - return false, nil - } + if src == "" { + dst.Set(reflect.Zero(dstT)) + return + } + switch { + case dstT == typeDuration: d, err := time.ParseDuration(src) if err != nil { return true, gperr.Wrap(err) @@ -445,30 +464,22 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe } dst.Set(reflect.ValueOf(d)) return - case typeURL: - if src == "" { - dst.Addr().Set(nilPointer[*url.URL]()) - return - } + case isSameOrEmbededType(dstT, typeURL): u, err := url.Parse(src) if err != nil { return true, gperr.Wrap(err) } - dst.Set(reflect.ValueOf(u).Elem()) + setSameOrEmbedddType(reflect.ValueOf(u).Elem(), dst) return - case typeCIDR: - if src == "" { - dst.Addr().Set(nilPointer[*net.IPNet]()) - return - } - if !strings.Contains(src, "/") { + case isSameOrEmbededType(dstT, typeCIDR): + if !strings.ContainsRune(src, '/') { src += "/32" // single IP } _, ipnet, err := net.ParseCIDR(src) if err != nil { return true, gperr.Wrap(err) } - dst.Set(reflect.ValueOf(ipnet).Elem()) + setSameOrEmbedddType(reflect.ValueOf(ipnet).Elem(), dst) return } if dstKind := dst.Kind(); isIntFloat(dstKind) { diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 86a8dc6..7e0ea5d 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -1,6 +1,8 @@ package utils import ( + "fmt" + "net" "net/url" "reflect" "strconv" @@ -12,7 +14,7 @@ import ( "gopkg.in/yaml.v3" ) -func TestDeserialize(t *testing.T) { +func TestUnmarshal(t *testing.T) { type S struct { I int S string @@ -41,7 +43,7 @@ func TestDeserialize(t *testing.T) { } ) - t.Run("deserialize", func(t *testing.T) { + t.Run("unmarshal", func(t *testing.T) { var s2 S err := MapUnmarshalValidate(testStructSerialized, &s2) ExpectNoError(t, err) @@ -49,7 +51,7 @@ func TestDeserialize(t *testing.T) { }) } -func TestDeserializeAnonymousField(t *testing.T) { +func TestUnmarshalAnonymousField(t *testing.T) { type Anon struct { A, B int } @@ -77,59 +79,31 @@ func TestDeserializeAnonymousField(t *testing.T) { } func TestStringIntConvert(t *testing.T) { - s := "127" - test := struct { - i8 int8 - i16 int16 - i32 int32 - i64 int64 - u8 uint8 - u16 uint16 - u32 uint32 - u64 uint64 + I8 int8 + I16 int16 + I32 int32 + I64 int64 + U8 uint8 + U16 uint16 + U32 uint32 + U64 uint64 }{} - ok, err := ConvertString(s, reflect.ValueOf(&test.i8)) - + refl := reflect.ValueOf(&test) + for i := range refl.Elem().NumField() { + field := refl.Elem().Field(i) + t.Run(fmt.Sprintf("field_%s", field.Type().Name()), func(t *testing.T) { + ok, err := ConvertString("127", field) ExpectTrue(t, ok) ExpectNoError(t, err) - ExpectEqualValues(t, test.i8, int8(127)) + ExpectEqualValues(t, field.Interface(), 127) - ok, err = ConvertString(s, reflect.ValueOf(&test.i16)) - ExpectTrue(t, ok) + err = Convert(reflect.ValueOf(uint8(64)), field) ExpectNoError(t, err) - ExpectEqualValues(t, test.i16, int16(127)) - - ok, err = ConvertString(s, reflect.ValueOf(&test.i32)) - ExpectTrue(t, ok) - ExpectNoError(t, err) - ExpectEqualValues(t, test.i32, int32(127)) - - ok, err = ConvertString(s, reflect.ValueOf(&test.i64)) - ExpectTrue(t, ok) - ExpectNoError(t, err) - ExpectEqualValues(t, test.i64, int64(127)) - - ok, err = ConvertString(s, reflect.ValueOf(&test.u8)) - ExpectTrue(t, ok) - ExpectNoError(t, err) - ExpectEqualValues(t, test.u8, uint8(127)) - - ok, err = ConvertString(s, reflect.ValueOf(&test.u16)) - ExpectTrue(t, ok) - ExpectNoError(t, err) - ExpectEqualValues(t, test.u16, uint16(127)) - - ok, err = ConvertString(s, reflect.ValueOf(&test.u32)) - ExpectTrue(t, ok) - ExpectNoError(t, err) - ExpectEqualValues(t, test.u32, uint32(127)) - - ok, err = ConvertString(s, reflect.ValueOf(&test.u64)) - ExpectTrue(t, ok) - ExpectNoError(t, err) - ExpectEqualValues(t, test.u64, uint64(127)) + ExpectEqualValues(t, field.Interface(), 64) + }) + } } type testModel struct { @@ -164,8 +138,8 @@ func TestConvertor(t *testing.T) { ExpectEqualValues(t, m.Test.foo, 123) ExpectEqualValues(t, m.Test.bar, "123") - ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 123}, m)) - ExpectEqualValues(t, m.Baz, "123") + ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m)) + ExpectEqualValues(t, m.Baz, "456") }) t.Run("invalid", func(t *testing.T) { @@ -237,18 +211,28 @@ func BenchmarkStringToMapYAML(b *testing.B) { } func TestStringToStruct(t *testing.T) { - t.Run("yaml-like", func(t *testing.T) { - dst := struct { + type T struct { A string B int - }{} + } + t.Run("yaml-like simple", func(t *testing.T) { + var dst T convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst)) ExpectTrue(t, convertible) ExpectNoError(t, err) - ExpectEqualValues(t, dst, struct { - A string - B int - }{"a", 123}) + ExpectEqualValues(t, dst.A, "a") + ExpectEqualValues(t, dst.B, 123) + }) + + type T2 struct { + URL *url.URL + CIDR *net.IPNet + } + t.Run("yaml-like complex", func(t *testing.T) { + var dst T2 + convertible, err := ConvertString(" URL: http://example.com\n CIDR: 1.2.3.0/24", reflect.ValueOf(&dst)) + ExpectTrue(t, convertible) + ExpectNoError(t, err) }) }