fix: unmarshal and some tests

This commit is contained in:
yusing 2025-04-13 12:24:11 +08:00
parent be87d47ebb
commit 3f2dfe14b5
6 changed files with 100 additions and 101 deletions

View file

@ -2,6 +2,7 @@ package gperr
import (
"encoding/json"
"errors"
"slices"
"strings"
@ -64,7 +65,7 @@ func (err *withSubject) Prepend(subject string) *withSubject {
}
func (err *withSubject) Is(other error) bool {
return err.Err == other
return errors.Is(other, err.Err)
}
func (err *withSubject) Unwrap() error {

View file

@ -23,7 +23,7 @@ type (
Key, Value string
}
Host string
CIDR struct{ net.IPNet }
CIDR net.IPNet
)
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
@ -85,7 +85,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
return req.Host == string(h)
}
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool {
ip, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
ip = req.RemoteAddr
@ -94,5 +94,9 @@ func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
if netIP == nil {
return false
}
return cidr.Contains(netIP)
return (*net.IPNet)(cidr).Contains(netIP)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
}

View file

@ -157,11 +157,9 @@ func TestHeaderFilter(t *testing.T) {
func TestCIDRFilter(t *testing.T) {
cidr := []*CIDR{{
net.IPNet{
IP: net.ParseIP("192.168.10.0"),
Mask: net.CIDRMask(24, 32),
}},
}
IP: net.ParseIP("192.168.10.0"),
Mask: net.CIDRMask(24, 32),
}}
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
inCIDR := &http.Request{
RemoteAddr: "192.168.10.1",

View file

@ -39,6 +39,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := Route{}
tt.input["host"] = "internal"
err := utils.MapUnmarshalValidate(tt.input, &cfg)
if err != nil {
ExpectNoError(t, err)

View file

@ -47,11 +47,15 @@ var (
var (
typeDuration = reflect.TypeFor[time.Duration]()
typeTime = reflect.TypeFor[time.Time]()
typeURL = reflect.TypeFor[url.URL]()
typeCIDR = reflect.TypeFor[*net.IPNet]()
typeCIDR = reflect.TypeFor[net.IPNet]()
typeMapMarshaller = reflect.TypeFor[MapMarshaller]()
typeMapUnmarshaler = reflect.TypeFor[MapUnmarshaller]()
typeJSONMarshaller = reflect.TypeFor[json.Marshaler]()
typeAny = reflect.TypeOf((*any)(nil)).Elem()
)
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
@ -360,20 +364,26 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
return err
}
case isIntFloat(srcKind):
var strV string
switch {
case src.CanInt():
strV = strconv.FormatInt(src.Int(), 10)
case srcKind == reflect.Bool:
strV = strconv.FormatBool(src.Bool())
case src.CanUint():
strV = strconv.FormatUint(src.Uint(), 10)
case src.CanFloat():
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
if dst.Kind() == reflect.String {
var strV string
switch {
case src.CanInt():
strV = strconv.FormatInt(src.Int(), 10)
case srcKind == reflect.Bool:
strV = strconv.FormatBool(src.Bool())
case src.CanUint():
strV = strconv.FormatUint(src.Uint(), 10)
case src.CanFloat():
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
}
dst.Set(reflect.ValueOf(strV))
return nil
}
if convertible, err := ConvertString(strV, dst); convertible {
return err
if !isIntFloat(dstT.Kind()) || !src.CanConvert(dstT) {
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}
dst.Set(src.Convert(dstT))
return nil
case srcKind == reflect.Map:
if src.Len() == 0 {
return nil
@ -412,8 +422,17 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
}
func nilPointer[T any]() reflect.Value {
return reflect.ValueOf((*T)(nil))
func isSameOrEmbededType(src, dst reflect.Type) bool {
return src == dst || src.ConvertibleTo(dst)
}
func setSameOrEmbedddType(src, dst reflect.Value) {
dstT := dst.Type()
if src.Type().AssignableTo(dstT) {
dst.Set(src)
} else {
dst.Set(src.Convert(dstT))
}
}
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) {
@ -430,12 +449,12 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
dst.SetString(src)
return
}
switch dstT {
case typeDuration:
if src == "" {
dst.Set(reflect.Zero(dstT))
return false, nil
}
if src == "" {
dst.Set(reflect.Zero(dstT))
return
}
switch {
case dstT == typeDuration:
d, err := time.ParseDuration(src)
if err != nil {
return true, gperr.Wrap(err)
@ -445,30 +464,22 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
}
dst.Set(reflect.ValueOf(d))
return
case typeURL:
if src == "" {
dst.Addr().Set(nilPointer[*url.URL]())
return
}
case isSameOrEmbededType(dstT, typeURL):
u, err := url.Parse(src)
if err != nil {
return true, gperr.Wrap(err)
}
dst.Set(reflect.ValueOf(u).Elem())
setSameOrEmbedddType(reflect.ValueOf(u).Elem(), dst)
return
case typeCIDR:
if src == "" {
dst.Addr().Set(nilPointer[*net.IPNet]())
return
}
if !strings.Contains(src, "/") {
case isSameOrEmbededType(dstT, typeCIDR):
if !strings.ContainsRune(src, '/') {
src += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(src)
if err != nil {
return true, gperr.Wrap(err)
}
dst.Set(reflect.ValueOf(ipnet).Elem())
setSameOrEmbedddType(reflect.ValueOf(ipnet).Elem(), dst)
return
}
if dstKind := dst.Kind(); isIntFloat(dstKind) {

View file

@ -1,6 +1,8 @@
package utils
import (
"fmt"
"net"
"net/url"
"reflect"
"strconv"
@ -12,7 +14,7 @@ import (
"gopkg.in/yaml.v3"
)
func TestDeserialize(t *testing.T) {
func TestUnmarshal(t *testing.T) {
type S struct {
I int
S string
@ -41,7 +43,7 @@ func TestDeserialize(t *testing.T) {
}
)
t.Run("deserialize", func(t *testing.T) {
t.Run("unmarshal", func(t *testing.T) {
var s2 S
err := MapUnmarshalValidate(testStructSerialized, &s2)
ExpectNoError(t, err)
@ -49,7 +51,7 @@ func TestDeserialize(t *testing.T) {
})
}
func TestDeserializeAnonymousField(t *testing.T) {
func TestUnmarshalAnonymousField(t *testing.T) {
type Anon struct {
A, B int
}
@ -77,59 +79,31 @@ func TestDeserializeAnonymousField(t *testing.T) {
}
func TestStringIntConvert(t *testing.T) {
s := "127"
test := struct {
i8 int8
i16 int16
i32 int32
i64 int64
u8 uint8
u16 uint16
u32 uint32
u64 uint64
I8 int8
I16 int16
I32 int32
I64 int64
U8 uint8
U16 uint16
U32 uint32
U64 uint64
}{}
ok, err := ConvertString(s, reflect.ValueOf(&test.i8))
refl := reflect.ValueOf(&test)
for i := range refl.Elem().NumField() {
field := refl.Elem().Field(i)
t.Run(fmt.Sprintf("field_%s", field.Type().Name()), func(t *testing.T) {
ok, err := ConvertString("127", field)
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.i8, int8(127))
ExpectEqualValues(t, field.Interface(), 127)
ok, err = ConvertString(s, reflect.ValueOf(&test.i16))
ExpectTrue(t, ok)
err = Convert(reflect.ValueOf(uint8(64)), field)
ExpectNoError(t, err)
ExpectEqualValues(t, test.i16, int16(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i32))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.i32, int32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.i64))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.i64, int64(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u8))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u8, uint8(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u16))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u16, uint16(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u32))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u32, uint32(127))
ok, err = ConvertString(s, reflect.ValueOf(&test.u64))
ExpectTrue(t, ok)
ExpectNoError(t, err)
ExpectEqualValues(t, test.u64, uint64(127))
ExpectEqualValues(t, field.Interface(), 64)
})
}
}
type testModel struct {
@ -164,8 +138,8 @@ func TestConvertor(t *testing.T) {
ExpectEqualValues(t, m.Test.foo, 123)
ExpectEqualValues(t, m.Test.bar, "123")
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 123}, m))
ExpectEqualValues(t, m.Baz, "123")
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m))
ExpectEqualValues(t, m.Baz, "456")
})
t.Run("invalid", func(t *testing.T) {
@ -237,18 +211,28 @@ func BenchmarkStringToMapYAML(b *testing.B) {
}
func TestStringToStruct(t *testing.T) {
t.Run("yaml-like", func(t *testing.T) {
dst := struct {
type T struct {
A string
B int
}{}
}
t.Run("yaml-like simple", func(t *testing.T) {
var dst T
convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
ExpectEqualValues(t, dst, struct {
A string
B int
}{"a", 123})
ExpectEqualValues(t, dst.A, "a")
ExpectEqualValues(t, dst.B, 123)
})
type T2 struct {
URL *url.URL
CIDR *net.IPNet
}
t.Run("yaml-like complex", func(t *testing.T) {
var dst T2
convertible, err := ConvertString(" URL: http://example.com\n CIDR: 1.2.3.0/24", reflect.ValueOf(&dst))
ExpectTrue(t, convertible)
ExpectNoError(t, err)
})
}