fix: unmarshal and some tests

This commit is contained in:
yusing 2025-04-13 12:24:11 +08:00
parent be87d47ebb
commit 3f2dfe14b5
6 changed files with 100 additions and 101 deletions

View file

@ -2,6 +2,7 @@ package gperr
import ( import (
"encoding/json" "encoding/json"
"errors"
"slices" "slices"
"strings" "strings"
@ -64,7 +65,7 @@ func (err *withSubject) Prepend(subject string) *withSubject {
} }
func (err *withSubject) Is(other error) bool { func (err *withSubject) Is(other error) bool {
return err.Err == other return errors.Is(other, err.Err)
} }
func (err *withSubject) Unwrap() error { func (err *withSubject) Unwrap() error {

View file

@ -23,7 +23,7 @@ type (
Key, Value string Key, Value string
} }
Host string Host string
CIDR struct{ net.IPNet } CIDR net.IPNet
) )
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter") var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
@ -85,7 +85,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
return req.Host == string(h) return req.Host == string(h)
} }
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool { func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool {
ip, _, err := net.SplitHostPort(req.RemoteAddr) ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil { if err != nil {
ip = req.RemoteAddr ip = req.RemoteAddr
@ -94,5 +94,9 @@ func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
if netIP == nil { if netIP == nil {
return false return false
} }
return cidr.Contains(netIP) return (*net.IPNet)(cidr).Contains(netIP)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
} }

View file

@ -157,11 +157,9 @@ func TestHeaderFilter(t *testing.T) {
func TestCIDRFilter(t *testing.T) { func TestCIDRFilter(t *testing.T) {
cidr := []*CIDR{{ cidr := []*CIDR{{
net.IPNet{ IP: net.ParseIP("192.168.10.0"),
IP: net.ParseIP("192.168.10.0"), Mask: net.CIDRMask(24, 32),
Mask: net.CIDRMask(24, 32), }}
}},
}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24") ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{ inCIDR := &http.Request{
RemoteAddr: "192.168.10.1", RemoteAddr: "192.168.10.1",

View file

@ -39,6 +39,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cfg := Route{} cfg := Route{}
tt.input["host"] = "internal"
err := utils.MapUnmarshalValidate(tt.input, &cfg) err := utils.MapUnmarshalValidate(tt.input, &cfg)
if err != nil { if err != nil {
ExpectNoError(t, err) ExpectNoError(t, err)

View file

@ -47,11 +47,15 @@ var (
var ( var (
typeDuration = reflect.TypeFor[time.Duration]() typeDuration = reflect.TypeFor[time.Duration]()
typeTime = reflect.TypeFor[time.Time]()
typeURL = reflect.TypeFor[url.URL]() typeURL = reflect.TypeFor[url.URL]()
typeCIDR = reflect.TypeFor[*net.IPNet]() typeCIDR = reflect.TypeFor[net.IPNet]()
typeMapMarshaller = reflect.TypeFor[MapMarshaller]() typeMapMarshaller = reflect.TypeFor[MapMarshaller]()
typeMapUnmarshaler = reflect.TypeFor[MapUnmarshaller]() typeMapUnmarshaler = reflect.TypeFor[MapUnmarshaller]()
typeJSONMarshaller = reflect.TypeFor[json.Marshaler]()
typeAny = reflect.TypeOf((*any)(nil)).Elem()
) )
var defaultValues = functional.NewMapOf[reflect.Type, func() any]() var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
@ -360,20 +364,26 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
return err return err
} }
case isIntFloat(srcKind): case isIntFloat(srcKind):
var strV string if dst.Kind() == reflect.String {
switch { var strV string
case src.CanInt(): switch {
strV = strconv.FormatInt(src.Int(), 10) case src.CanInt():
case srcKind == reflect.Bool: strV = strconv.FormatInt(src.Int(), 10)
strV = strconv.FormatBool(src.Bool()) case srcKind == reflect.Bool:
case src.CanUint(): strV = strconv.FormatBool(src.Bool())
strV = strconv.FormatUint(src.Uint(), 10) case src.CanUint():
case src.CanFloat(): strV = strconv.FormatUint(src.Uint(), 10)
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64) case src.CanFloat():
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
}
dst.Set(reflect.ValueOf(strV))
return nil
} }
if convertible, err := ConvertString(strV, dst); convertible { if !isIntFloat(dstT.Kind()) || !src.CanConvert(dstT) {
return err return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
} }
dst.Set(src.Convert(dstT))
return nil
case srcKind == reflect.Map: case srcKind == reflect.Map:
if src.Len() == 0 { if src.Len() == 0 {
return nil return nil
@ -412,8 +422,17 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT) return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
} }
func nilPointer[T any]() reflect.Value { func isSameOrEmbededType(src, dst reflect.Type) bool {
return reflect.ValueOf((*T)(nil)) return src == dst || src.ConvertibleTo(dst)
}
func setSameOrEmbedddType(src, dst reflect.Value) {
dstT := dst.Type()
if src.Type().AssignableTo(dstT) {
dst.Set(src)
} else {
dst.Set(src.Convert(dstT))
}
} }
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) { func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) {
@ -430,12 +449,12 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
dst.SetString(src) dst.SetString(src)
return return
} }
switch dstT { if src == "" {
case typeDuration: dst.Set(reflect.Zero(dstT))
if src == "" { return
dst.Set(reflect.Zero(dstT)) }
return false, nil switch {
} case dstT == typeDuration:
d, err := time.ParseDuration(src) d, err := time.ParseDuration(src)
if err != nil { if err != nil {
return true, gperr.Wrap(err) return true, gperr.Wrap(err)
@ -445,30 +464,22 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
} }
dst.Set(reflect.ValueOf(d)) dst.Set(reflect.ValueOf(d))
return return
case typeURL: case isSameOrEmbededType(dstT, typeURL):
if src == "" {
dst.Addr().Set(nilPointer[*url.URL]())
return
}
u, err := url.Parse(src) u, err := url.Parse(src)
if err != nil { if err != nil {
return true, gperr.Wrap(err) return true, gperr.Wrap(err)
} }
dst.Set(reflect.ValueOf(u).Elem()) setSameOrEmbedddType(reflect.ValueOf(u).Elem(), dst)
return return
case typeCIDR: case isSameOrEmbededType(dstT, typeCIDR):
if src == "" { if !strings.ContainsRune(src, '/') {
dst.Addr().Set(nilPointer[*net.IPNet]())
return
}
if !strings.Contains(src, "/") {
src += "/32" // single IP src += "/32" // single IP
} }
_, ipnet, err := net.ParseCIDR(src) _, ipnet, err := net.ParseCIDR(src)
if err != nil { if err != nil {
return true, gperr.Wrap(err) return true, gperr.Wrap(err)
} }
dst.Set(reflect.ValueOf(ipnet).Elem()) setSameOrEmbedddType(reflect.ValueOf(ipnet).Elem(), dst)
return return
} }
if dstKind := dst.Kind(); isIntFloat(dstKind) { if dstKind := dst.Kind(); isIntFloat(dstKind) {

View file

@ -1,6 +1,8 @@
package utils package utils
import ( import (
"fmt"
"net"
"net/url" "net/url"
"reflect" "reflect"
"strconv" "strconv"
@ -12,7 +14,7 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
func TestDeserialize(t *testing.T) { func TestUnmarshal(t *testing.T) {
type S struct { type S struct {
I int I int
S string S string
@ -41,7 +43,7 @@ func TestDeserialize(t *testing.T) {
} }
) )
t.Run("deserialize", func(t *testing.T) { t.Run("unmarshal", func(t *testing.T) {
var s2 S var s2 S
err := MapUnmarshalValidate(testStructSerialized, &s2) err := MapUnmarshalValidate(testStructSerialized, &s2)
ExpectNoError(t, err) ExpectNoError(t, err)
@ -49,7 +51,7 @@ func TestDeserialize(t *testing.T) {
}) })
} }
func TestDeserializeAnonymousField(t *testing.T) { func TestUnmarshalAnonymousField(t *testing.T) {
type Anon struct { type Anon struct {
A, B int A, B int
} }
@ -77,59 +79,31 @@ func TestDeserializeAnonymousField(t *testing.T) {
} }
func TestStringIntConvert(t *testing.T) { func TestStringIntConvert(t *testing.T) {
s := "127"
test := struct { test := struct {
i8 int8 I8 int8
i16 int16 I16 int16
i32 int32 I32 int32
i64 int64 I64 int64
u8 uint8 U8 uint8
u16 uint16 U16 uint16
u32 uint32 U32 uint32
u64 uint64 U64 uint64
}{} }{}
ok, err := ConvertString(s, reflect.ValueOf(&test.i8)) refl := reflect.ValueOf(&test)
for i := range refl.Elem().NumField() {
field := refl.Elem().Field(i)
t.Run(fmt.Sprintf("field_%s", field.Type().Name()), func(t *testing.T) {
ok, err := ConvertString("127", field)
ExpectTrue(t, ok) ExpectTrue(t, ok)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqualValues(t, test.i8, int8(127)) ExpectEqualValues(t, field.Interface(), 127)
ok, err = ConvertString(s, reflect.ValueOf(&test.i16)) err = Convert(reflect.ValueOf(uint8(64)), field)
ExpectTrue(t, ok)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqualValues(t, test.i16, int16(127)) ExpectEqualValues(t, field.Interface(), 64)
})
ok, err = ConvertString(s, reflect.ValueOf(&test.i32)) }
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.i32, int32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i64))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.i64, int64(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u8))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u8, uint8(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u16))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u16, uint16(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u32))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u32, uint32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u64))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u64, uint64(127))
} }
type testModel struct { type testModel struct {
@ -164,8 +138,8 @@ func TestConvertor(t *testing.T) {
ExpectEqualValues(t, m.Test.foo, 123) ExpectEqualValues(t, m.Test.foo, 123)
ExpectEqualValues(t, m.Test.bar, "123") ExpectEqualValues(t, m.Test.bar, "123")
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 123}, m)) ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m))
ExpectEqualValues(t, m.Baz, "123") ExpectEqualValues(t, m.Baz, "456")
}) })
t.Run("invalid", func(t *testing.T) { t.Run("invalid", func(t *testing.T) {
@ -237,18 +211,28 @@ func BenchmarkStringToMapYAML(b *testing.B) {
} }
func TestStringToStruct(t *testing.T) { func TestStringToStruct(t *testing.T) {
t.Run("yaml-like", func(t *testing.T) { type T struct {
dst := struct {
A string A string
B int B int
}{} }
t.Run("yaml-like simple", func(t *testing.T) {
var dst T
convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst)) convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst))
ExpectTrue(t, convertible) ExpectTrue(t, convertible)
ExpectNoError(t, err) ExpectNoError(t, err)
ExpectEqualValues(t, dst, struct { ExpectEqualValues(t, dst.A, "a")
A string ExpectEqualValues(t, dst.B, 123)
B int })
}{"a", 123})
type T2 struct {
URL *url.URL
CIDR *net.IPNet
}
t.Run("yaml-like complex", func(t *testing.T) {
var dst T2
convertible, err := ConvertString(" URL: http://example.com\n CIDR: 1.2.3.0/24", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
}) })
} }