mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix: unmarshal and some tests
This commit is contained in:
parent
be87d47ebb
commit
3f2dfe14b5
6 changed files with 100 additions and 101 deletions
|
@ -2,6 +2,7 @@ package gperr
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
|
@ -64,7 +65,7 @@ func (err *withSubject) Prepend(subject string) *withSubject {
|
|||
}
|
||||
|
||||
func (err *withSubject) Is(other error) bool {
|
||||
return err.Err == other
|
||||
return errors.Is(other, err.Err)
|
||||
}
|
||||
|
||||
func (err *withSubject) Unwrap() error {
|
||||
|
|
|
@ -23,7 +23,7 @@ type (
|
|||
Key, Value string
|
||||
}
|
||||
Host string
|
||||
CIDR struct{ net.IPNet }
|
||||
CIDR net.IPNet
|
||||
)
|
||||
|
||||
var ErrInvalidHTTPHeaderFilter = gperr.New("invalid http header filter")
|
||||
|
@ -85,7 +85,7 @@ func (h Host) Fulfill(req *http.Request, res *http.Response) bool {
|
|||
return req.Host == string(h)
|
||||
}
|
||||
|
||||
func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
|
||||
func (cidr *CIDR) Fulfill(req *http.Request, res *http.Response) bool {
|
||||
ip, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = req.RemoteAddr
|
||||
|
@ -94,5 +94,9 @@ func (cidr CIDR) Fulfill(req *http.Request, res *http.Response) bool {
|
|||
if netIP == nil {
|
||||
return false
|
||||
}
|
||||
return cidr.Contains(netIP)
|
||||
return (*net.IPNet)(cidr).Contains(netIP)
|
||||
}
|
||||
|
||||
func (cidr *CIDR) String() string {
|
||||
return (*net.IPNet)(cidr).String()
|
||||
}
|
||||
|
|
|
@ -157,11 +157,9 @@ func TestHeaderFilter(t *testing.T) {
|
|||
|
||||
func TestCIDRFilter(t *testing.T) {
|
||||
cidr := []*CIDR{{
|
||||
net.IPNet{
|
||||
IP: net.ParseIP("192.168.10.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
}},
|
||||
}
|
||||
IP: net.ParseIP("192.168.10.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
}}
|
||||
ExpectEqual(t, cidr[0].String(), "192.168.10.0/24")
|
||||
inCIDR := &http.Request{
|
||||
RemoteAddr: "192.168.10.1",
|
||||
|
|
|
@ -39,6 +39,7 @@ func TestHTTPConfigDeserialize(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := Route{}
|
||||
tt.input["host"] = "internal"
|
||||
err := utils.MapUnmarshalValidate(tt.input, &cfg)
|
||||
if err != nil {
|
||||
ExpectNoError(t, err)
|
||||
|
|
|
@ -47,11 +47,15 @@ var (
|
|||
|
||||
var (
|
||||
typeDuration = reflect.TypeFor[time.Duration]()
|
||||
typeTime = reflect.TypeFor[time.Time]()
|
||||
typeURL = reflect.TypeFor[url.URL]()
|
||||
typeCIDR = reflect.TypeFor[*net.IPNet]()
|
||||
typeCIDR = reflect.TypeFor[net.IPNet]()
|
||||
|
||||
typeMapMarshaller = reflect.TypeFor[MapMarshaller]()
|
||||
typeMapUnmarshaler = reflect.TypeFor[MapUnmarshaller]()
|
||||
typeJSONMarshaller = reflect.TypeFor[json.Marshaler]()
|
||||
|
||||
typeAny = reflect.TypeOf((*any)(nil)).Elem()
|
||||
)
|
||||
|
||||
var defaultValues = functional.NewMapOf[reflect.Type, func() any]()
|
||||
|
@ -360,20 +364,26 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
|
|||
return err
|
||||
}
|
||||
case isIntFloat(srcKind):
|
||||
var strV string
|
||||
switch {
|
||||
case src.CanInt():
|
||||
strV = strconv.FormatInt(src.Int(), 10)
|
||||
case srcKind == reflect.Bool:
|
||||
strV = strconv.FormatBool(src.Bool())
|
||||
case src.CanUint():
|
||||
strV = strconv.FormatUint(src.Uint(), 10)
|
||||
case src.CanFloat():
|
||||
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
|
||||
if dst.Kind() == reflect.String {
|
||||
var strV string
|
||||
switch {
|
||||
case src.CanInt():
|
||||
strV = strconv.FormatInt(src.Int(), 10)
|
||||
case srcKind == reflect.Bool:
|
||||
strV = strconv.FormatBool(src.Bool())
|
||||
case src.CanUint():
|
||||
strV = strconv.FormatUint(src.Uint(), 10)
|
||||
case src.CanFloat():
|
||||
strV = strconv.FormatFloat(src.Float(), 'f', -1, 64)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(strV))
|
||||
return nil
|
||||
}
|
||||
if convertible, err := ConvertString(strV, dst); convertible {
|
||||
return err
|
||||
if !isIntFloat(dstT.Kind()) || !src.CanConvert(dstT) {
|
||||
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
|
||||
}
|
||||
dst.Set(src.Convert(dstT))
|
||||
return nil
|
||||
case srcKind == reflect.Map:
|
||||
if src.Len() == 0 {
|
||||
return nil
|
||||
|
@ -412,8 +422,17 @@ func Convert(src reflect.Value, dst reflect.Value) gperr.Error {
|
|||
return ErrUnsupportedConversion.Subjectf("%s to %s", srcT, dstT)
|
||||
}
|
||||
|
||||
func nilPointer[T any]() reflect.Value {
|
||||
return reflect.ValueOf((*T)(nil))
|
||||
func isSameOrEmbededType(src, dst reflect.Type) bool {
|
||||
return src == dst || src.ConvertibleTo(dst)
|
||||
}
|
||||
|
||||
func setSameOrEmbedddType(src, dst reflect.Value) {
|
||||
dstT := dst.Type()
|
||||
if src.Type().AssignableTo(dstT) {
|
||||
dst.Set(src)
|
||||
} else {
|
||||
dst.Set(src.Convert(dstT))
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gperr.Error) {
|
||||
|
@ -430,12 +449,12 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
|
|||
dst.SetString(src)
|
||||
return
|
||||
}
|
||||
switch dstT {
|
||||
case typeDuration:
|
||||
if src == "" {
|
||||
dst.Set(reflect.Zero(dstT))
|
||||
return false, nil
|
||||
}
|
||||
if src == "" {
|
||||
dst.Set(reflect.Zero(dstT))
|
||||
return
|
||||
}
|
||||
switch {
|
||||
case dstT == typeDuration:
|
||||
d, err := time.ParseDuration(src)
|
||||
if err != nil {
|
||||
return true, gperr.Wrap(err)
|
||||
|
@ -445,30 +464,22 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr gpe
|
|||
}
|
||||
dst.Set(reflect.ValueOf(d))
|
||||
return
|
||||
case typeURL:
|
||||
if src == "" {
|
||||
dst.Addr().Set(nilPointer[*url.URL]())
|
||||
return
|
||||
}
|
||||
case isSameOrEmbededType(dstT, typeURL):
|
||||
u, err := url.Parse(src)
|
||||
if err != nil {
|
||||
return true, gperr.Wrap(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(u).Elem())
|
||||
setSameOrEmbedddType(reflect.ValueOf(u).Elem(), dst)
|
||||
return
|
||||
case typeCIDR:
|
||||
if src == "" {
|
||||
dst.Addr().Set(nilPointer[*net.IPNet]())
|
||||
return
|
||||
}
|
||||
if !strings.Contains(src, "/") {
|
||||
case isSameOrEmbededType(dstT, typeCIDR):
|
||||
if !strings.ContainsRune(src, '/') {
|
||||
src += "/32" // single IP
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(src)
|
||||
if err != nil {
|
||||
return true, gperr.Wrap(err)
|
||||
}
|
||||
dst.Set(reflect.ValueOf(ipnet).Elem())
|
||||
setSameOrEmbedddType(reflect.ValueOf(ipnet).Elem(), dst)
|
||||
return
|
||||
}
|
||||
if dstKind := dst.Kind(); isIntFloat(dstKind) {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
@ -12,7 +14,7 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestDeserialize(t *testing.T) {
|
||||
func TestUnmarshal(t *testing.T) {
|
||||
type S struct {
|
||||
I int
|
||||
S string
|
||||
|
@ -41,7 +43,7 @@ func TestDeserialize(t *testing.T) {
|
|||
}
|
||||
)
|
||||
|
||||
t.Run("deserialize", func(t *testing.T) {
|
||||
t.Run("unmarshal", func(t *testing.T) {
|
||||
var s2 S
|
||||
err := MapUnmarshalValidate(testStructSerialized, &s2)
|
||||
ExpectNoError(t, err)
|
||||
|
@ -49,7 +51,7 @@ func TestDeserialize(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDeserializeAnonymousField(t *testing.T) {
|
||||
func TestUnmarshalAnonymousField(t *testing.T) {
|
||||
type Anon struct {
|
||||
A, B int
|
||||
}
|
||||
|
@ -77,59 +79,31 @@ func TestDeserializeAnonymousField(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestStringIntConvert(t *testing.T) {
|
||||
s := "127"
|
||||
|
||||
test := struct {
|
||||
i8 int8
|
||||
i16 int16
|
||||
i32 int32
|
||||
i64 int64
|
||||
u8 uint8
|
||||
u16 uint16
|
||||
u32 uint32
|
||||
u64 uint64
|
||||
I8 int8
|
||||
I16 int16
|
||||
I32 int32
|
||||
I64 int64
|
||||
U8 uint8
|
||||
U16 uint16
|
||||
U32 uint32
|
||||
U64 uint64
|
||||
}{}
|
||||
|
||||
ok, err := ConvertString(s, reflect.ValueOf(&test.i8))
|
||||
|
||||
refl := reflect.ValueOf(&test)
|
||||
for i := range refl.Elem().NumField() {
|
||||
field := refl.Elem().Field(i)
|
||||
t.Run(fmt.Sprintf("field_%s", field.Type().Name()), func(t *testing.T) {
|
||||
ok, err := ConvertString("127", field)
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.i8, int8(127))
|
||||
ExpectEqualValues(t, field.Interface(), 127)
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.i16))
|
||||
ExpectTrue(t, ok)
|
||||
err = Convert(reflect.ValueOf(uint8(64)), field)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.i16, int16(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.i32))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.i32, int32(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.i64))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.i64, int64(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u8))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.u8, uint8(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u16))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.u16, uint16(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u32))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.u32, uint32(127))
|
||||
|
||||
ok, err = ConvertString(s, reflect.ValueOf(&test.u64))
|
||||
ExpectTrue(t, ok)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, test.u64, uint64(127))
|
||||
ExpectEqualValues(t, field.Interface(), 64)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testModel struct {
|
||||
|
@ -164,8 +138,8 @@ func TestConvertor(t *testing.T) {
|
|||
ExpectEqualValues(t, m.Test.foo, 123)
|
||||
ExpectEqualValues(t, m.Test.bar, "123")
|
||||
|
||||
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 123}, m))
|
||||
ExpectEqualValues(t, m.Baz, "123")
|
||||
ExpectNoError(t, MapUnmarshalValidate(map[string]any{"Baz": 456}, m))
|
||||
ExpectEqualValues(t, m.Baz, "456")
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
|
@ -237,18 +211,28 @@ func BenchmarkStringToMapYAML(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestStringToStruct(t *testing.T) {
|
||||
t.Run("yaml-like", func(t *testing.T) {
|
||||
dst := struct {
|
||||
type T struct {
|
||||
A string
|
||||
B int
|
||||
}{}
|
||||
}
|
||||
t.Run("yaml-like simple", func(t *testing.T) {
|
||||
var dst T
|
||||
convertible, err := ConvertString(" A: a\n B: 123", reflect.ValueOf(&dst))
|
||||
ExpectTrue(t, convertible)
|
||||
ExpectNoError(t, err)
|
||||
ExpectEqualValues(t, dst, struct {
|
||||
A string
|
||||
B int
|
||||
}{"a", 123})
|
||||
ExpectEqualValues(t, dst.A, "a")
|
||||
ExpectEqualValues(t, dst.B, 123)
|
||||
})
|
||||
|
||||
type T2 struct {
|
||||
URL *url.URL
|
||||
CIDR *net.IPNet
|
||||
}
|
||||
t.Run("yaml-like complex", func(t *testing.T) {
|
||||
var dst T2
|
||||
convertible, err := ConvertString(" URL: http://example.com\n CIDR: 1.2.3.0/24", reflect.ValueOf(&dst))
|
||||
ExpectTrue(t, convertible)
|
||||
ExpectNoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue