implement middleware compose

This commit is contained in:
yusing 2024-10-01 16:38:07 +08:00
parent f5a36f94bb
commit 44cfd65f6c
17 changed files with 392 additions and 152 deletions

View file

@ -11,6 +11,8 @@
A lightweight, easy-to-use, and [performant](docs/benchmark_result.md) reverse proxy with a web UI.
_Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
## Table of content
<!-- TOC -->

View file

@ -2,6 +2,7 @@ package error
import (
"fmt"
"strings"
"sync"
)
@ -24,7 +25,6 @@ func NewBuilder(format string, args ...any) Builder {
func (b Builder) Add(err NestedError) Builder {
if err != nil {
b.Lock()
// TODO: if err severity is higher than b.severity, update b.severity
b.errors = append(b.errors, err)
b.Unlock()
}
@ -49,6 +49,8 @@ func (b Builder) Addf(format string, args ...any) Builder {
func (b Builder) Build() NestedError {
if len(b.errors) == 0 {
return nil
} else if len(b.errors) == 1 && !strings.ContainsRune(b.message, ' ') {
return b.errors[0].Subject(b.message)
}
return Join(b.message, b.errors...)
}

View file

@ -166,6 +166,8 @@ func (ne NestedError) Subject(s any) NestedError {
}
if ne.subject == "" {
ne.subject = subject
} else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') {
ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject)
} else {
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
}

View file

@ -6,15 +6,16 @@ import (
)
var (
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrFailure = stderrors.New("failed")
ErrInvalid = stderrors.New("invalid")
ErrUnsupported = stderrors.New("unsupported")
ErrUnexpected = stderrors.New("unexpected")
ErrNotExists = stderrors.New("does not exist")
ErrMissing = stderrors.New("missing")
ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch")
)
const fmtSubjectWhat = "%w %v: %q"
@ -63,6 +64,14 @@ func OutOfRange(subject any, value any) NestedError {
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
}
func TypeError(subject any, from, to reflect.Value) NestedError {
return errorf("%v %w: %T -> %T", subject, ErrTypeError, from.Interface(), to.Interface())
func TypeError(subject any, from, to reflect.Type) NestedError {
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
}
func TypeError2(subject any, from, to reflect.Value) NestedError {
return TypeError(subject, from.Type(), to.Type())
}
func TypeMismatch[Expect any](value any) NestedError {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
}

View file

@ -13,6 +13,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
)
const (
@ -53,7 +54,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
return cri.m, nil
}
func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
}
@ -66,14 +67,14 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
}
if common.IsTest {
cfCIDRs = []*net.IPNet{
cfCIDRs = []*types.CIDR{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
}
} else {
cfCIDRs = make([]*net.IPNet, 0, 30)
cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
@ -90,7 +91,7 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
@ -110,7 +111,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} else {
cfCIDRs = append(cfCIDRs, cidr)
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
}
}

View file

@ -1,6 +1,7 @@
package middleware
import (
"encoding/json"
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
@ -53,7 +54,14 @@ func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
return nil, err
@ -87,7 +95,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
continue
}
m, err := m.WithOptionsClone(opts, rp)
m, err := m.WithOptionsClone(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue

View file

@ -8,24 +8,27 @@ import (
"gopkg.in/yaml.v3"
)
func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middleware, outErr E.NestedError) {
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, E.FailWith("read middleware compose file", err)
}
return BuildMiddlewaresFromYAML(fileContent)
}
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)
var data map[string][]map[string]any
fileContent, err := os.ReadFile(filePath)
if err != nil {
b.Add(E.FailWith("read file", err))
return
}
err = yaml.Unmarshal(fileContent, &data)
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("toml unmarshal", err))
return
}
middlewares = make(map[string]*Middleware)
for name, defs := range data {
chainErr := E.NewBuilder("errors in middleware chain %s", name)
for name, defs := range rawMap {
chainErr := E.NewBuilder(name)
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"].(string) == "" {
@ -39,9 +42,9 @@ func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middlewa
continue
}
delete(def, "use")
m, err := base.withOptions(def)
m, err := base.WithOptionsClone(def)
if err != nil {
chainErr.Add(err.Subjectf("%s.%d", name, i))
chainErr.Add(err.Subjectf("item%d", i))
continue
}
chain = append(chain, m)

View file

@ -1,9 +1,21 @@
package middleware
import (
_ "embed"
"encoding/json"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBuild(t *testing.T) {
//go:embed test_data/middleware_compose.yml
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
// middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
// ExpectNoError(t, err.Error())
data, err := E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
t.Log(string(data))
}

View file

@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"net/http"
"path"
"strings"
@ -15,27 +16,27 @@ import (
var middlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
middleware, ok = middlewares[strings.ToLower(name)]
return
}
// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded,
"hide_x_forwarded": HideXForwarded,
"redirect_http": RedirectHTTP,
"forward_auth": ForwardAuth.m,
"modify_response": ModifyResponse.m,
"modify_request": ModifyRequest.m,
"error_page": CustomErrorPage,
"custom_error_page": CustomErrorPage,
"real_ip": RealIP.m,
"cloudflare_real_ip": CloudflareRealIP.m,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"forwardauth": ForwardAuth.m,
"modifyresponse": ModifyResponse.m,
"modifyrequest": ModifyRequest.m,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
names[m] = append(names[m], http.CanonicalHeaderKey(name))
// register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil {
@ -49,6 +50,7 @@ func init() {
m.name = names[0]
}
}
// TODO: seperate from init()
b := E.NewBuilder("failed to load middlewares")
middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0)
@ -57,7 +59,7 @@ func init() {
return
}
for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromYAML(defFile)
mws, err := BuildMiddlewaresFromComposeFile(defFile)
for name, m := range mws {
if _, ok := middlewares[name]; ok {
b.Add(E.Duplicated("middleware", name))

View file

@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
mr, err := ModifyRequest.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
mr, err := ModifyResponse.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -2,11 +2,11 @@ package middleware
import (
"net"
"strings"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
@ -20,7 +20,7 @@ type realIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string
// From is a list of Address / CIDRs to trust
From []*net.IPNet
From []*types.CIDR
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
@ -35,7 +35,7 @@ type realIPOpts struct {
var RealIP = &realIP{
m: &Middleware{
labelParserMap: D.ValueParserMap{
"from": CIDRListParser,
"from": D.YamlStringListParser,
"recursive": D.BoolParser,
},
withOptions: NewRealIP,
@ -45,14 +45,7 @@ var RealIP = &realIP{
var realIPOptsDefault = func() *realIPOpts {
return &realIPOpts{
Header: "X-Real-IP",
From: []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
{IP: net.ParseIP("fc00::"), Mask: net.CIDRMask(7, 128)},
{IP: net.ParseIP("fe80::"), Mask: net.CIDRMask(10, 128)},
},
From: []*types.CIDR{},
}
}
@ -72,31 +65,6 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
return riWithOpts.m, nil
}
func CIDRListParser(s string) (any, E.NestedError) {
sl, err := D.YamlStringListParser(s)
if err != nil {
return nil, err
}
b := E.NewBuilder("invalid CIDR(s)")
CIDRs := sl.([]string)
res := make([]*net.IPNet, 0, len(CIDRs))
for _, cidr := range CIDRs {
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
b.Add(E.Invalid("CIDR", cidr))
continue
}
res = append(res, ipnet)
}
return res, b.Build()
}
func (ri *realIP) isInCIDRList(ip net.IP) bool {
for _, CIDR := range ri.From {
if CIDR.Contains(ip) {

View file

@ -0,0 +1,58 @@
package middleware
import (
"net"
"testing"
"github.com/yusing/go-proxy/internal/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIP(t *testing.T) {
opts := OptionsRaw{
"header": "X-Real-IP",
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
"172.16.0.0/12",
},
"recursive": true,
}
optExpected := &realIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
Mask: net.IPv4Mask(255, 0, 0, 0),
},
{
IP: net.ParseIP("192.168.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
{
IP: net.ParseIP("172.16.0.0"),
Mask: net.IPv4Mask(255, 240, 0, 0),
},
},
Recursive: true,
}
t.Run("set_options", func(t *testing.T) {
ri, err := RealIP.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
})
// t.Run("request_headers", func(t *testing.T) {
// result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
// middlewareOpt: opts,
// })
// ExpectNoError(t, err.Error())
// ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
// ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
// ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
// })
}

View file

@ -0,0 +1,41 @@
theGreatPretender:
- use: HideXForwarded
- use: ModifyRequest
setHeaders:
X-Real-IP: 6.6.6.6
- use: ModifyResponse
hideHeaders:
- X-Test3
- X-Test4
notAuthenticAuthentik:
- use: RedirectHTTP
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
addAuthCookiesToResponse:
- session_id
- user_id
authResponseHeaders:
- X-Auth-SessionID
- X-Auth-UserID
- use: CustomErrorPage
realIPAuthentik:
- use: RedirectHTTP
- use: RealIP
header: X-Real-IP
from:
- "127.0.0.0/8"
- "192.168.0.0/16"
- "172.16.0.0/12"
recursive: true
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
testFakeRealIP:
- use: ModifyRequest
setHeaders:
CF-Connecting-IP: 127.0.0.1
- use: CloudflareRealIP

34
internal/types/cidr.go Normal file
View file

@ -0,0 +1,34 @@
package types
import (
"net"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type CIDR net.IPNet
func (*CIDR) ConvertFrom(val any) (any, E.NestedError) {
cidr, ok := val.(string)
if !ok {
return nil, E.TypeMismatch[string](val)
}
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, E.Invalid("CIDR", cidr)
}
return (*CIDR)(ipnet), nil
}
func (cidr *CIDR) Contains(ip net.IP) bool {
return (*net.IPNet)(cidr).Contains(ip)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
}

View file

@ -12,6 +12,11 @@ import (
"gopkg.in/yaml.v3"
)
type SerializedObject = map[string]any
type Convertor interface {
ConvertFrom(value any) (any, E.NestedError)
}
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
var i any
@ -89,7 +94,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
} else if field.Anonymous {
// If the field is an embedded struct, add its fields to the result
fieldMap, err := Serialize(value.Field(i).Interface())
if err.HasError() {
if err != nil {
return nil, err
}
for k, v := range fieldMap {
@ -106,90 +111,138 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
return result, nil
}
func Deserialize(src SerializedObject, target any) E.NestedError {
if src == nil || target == nil {
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
//
// The target value must be a struct or a map[string]any.
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields.
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
//
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func Deserialize(src SerializedObject, dst any) E.NestedError {
if src == nil || dst == nil {
return nil
}
tValue := reflect.ValueOf(target)
mapping := make(map[string]string)
dstV := reflect.ValueOf(dst)
dstT := dstV.Type()
if tValue.Kind() == reflect.Ptr {
tValue = tValue.Elem()
if dstV.Kind() == reflect.Ptr {
dstV = dstV.Elem()
dstT = dstV.Type()
}
// convert data fields to lower no-snake
// convert target fields to lower no-snake
// then check if the field of data is in the target
if tValue.Kind() == reflect.Struct {
t := reflect.TypeOf(target).Elem()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
snakeCaseField := ToLowerNoSnake(field.Name)
mapping[snakeCaseField] = field.Name
// TODO: use E.Builder to collect errors from all fields
if dstV.Kind() == reflect.Struct {
mapping := make(map[string]reflect.Value)
for i := 0; i < dstV.NumField(); i++ {
field := dstT.Field(i)
mapping[ToLowerNoSnake(field.Name)] = dstV.Field(i)
}
} else if tValue.Kind() == reflect.Map && tValue.Type().Key().Kind() == reflect.String {
if tValue.IsNil() {
tValue.Set(reflect.MakeMap(tValue.Type()))
for k, v := range src {
if field, ok := mapping[ToLowerNoSnake(k)]; ok {
err := Convert(reflect.ValueOf(v), field)
if err != nil {
return err.Subject(k)
}
} else {
return E.Unexpected("field", k)
}
}
} else if dstV.Kind() == reflect.Map && dstT.Key().Kind() == reflect.String {
if dstV.IsNil() {
dstV.Set(reflect.MakeMap(dstT))
}
for k := range src {
// TODO: type check
tValue.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), reflect.ValueOf(src[k]))
tmp := reflect.New(dstT.Elem()).Elem()
err := Convert(reflect.ValueOf(src[k]), tmp)
if err != nil {
return err.Subject(k)
}
dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp)
}
return nil
} else {
return E.Unsupported("target type", fmt.Sprintf("%T", target))
return E.Unsupported("target type", fmt.Sprintf("%T", dst))
}
for k, v := range src {
kCleaned := ToLowerNoSnake(k)
if fieldName, ok := mapping[kCleaned]; ok {
prop := tValue.FieldByName(fieldName)
propType := prop.Type()
isPtr := prop.Kind() == reflect.Ptr
if prop.CanSet() {
val := reflect.ValueOf(v)
vType := val.Type()
switch {
case isPtr && vType.ConvertibleTo(propType.Elem()):
ptr := reflect.New(propType.Elem())
ptr.Elem().Set(val.Convert(propType.Elem()))
prop.Set(ptr)
case vType.ConvertibleTo(propType):
prop.Set(val.Convert(propType))
case isPtr:
var vSerialized SerializedObject
vSerialized, ok = v.(SerializedObject)
if !ok {
if vType.ConvertibleTo(reflect.TypeFor[SerializedObject]()) {
vSerialized = val.Convert(reflect.TypeFor[SerializedObject]()).Interface().(SerializedObject)
} else {
return E.Failure(fmt.Sprintf("convert %s (%T) to %s", k, v, reflect.TypeFor[SerializedObject]()))
}
}
propNew := reflect.New(propType.Elem())
err := Deserialize(vSerialized, propNew.Interface())
if err.HasError() {
return E.Failure("set field").With(err).Subject(k)
}
prop.Set(propNew)
default:
obj, ok := val.Interface().(SerializedObject)
if !ok {
return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType)
}
err := Deserialize(obj, prop.Addr().Interface())
if err.HasError() {
return E.Failure("set field").With(err).Subject(k)
}
}
} else {
return E.Unsupported("field", k).Extraf("type %s is not settable", propType)
}
} else {
return E.Unexpected("field", k)
return nil
}
// Convert attempts to convert the src to dst.
//
// If src is a map, it is deserialized into dst.
// If src is a slice, each of its elements are converted and stored in dst.
// For any other type, it is converted using the reflect.Value.Convert function (if possible).
//
// If dst is not settable, an error is returned.
// If src cannot be converted to dst, an error is returned.
// If any error occurs during conversion (e.g. deserialization), it is returned.
//
// Returns:
// - error: the error occurred during conversion, or nil if no error occurred.
func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
srcT := src.Type()
dstVT := dst.Type()
if src.Kind() == reflect.Interface {
src = src.Elem()
srcT = src.Type()
}
if !dst.CanSet() {
return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface()))
}
switch {
case srcT.AssignableTo(dstVT):
dst.Set(src)
case srcT.ConvertibleTo(dstVT):
dst.Set(src.Convert(dstVT))
case srcT.Kind() == reflect.Map:
if dstVT.Kind() != reflect.Map {
return E.TypeError("map", srcT, dstVT)
}
obj, ok := src.Interface().(SerializedObject)
if !ok {
return E.TypeError("map", srcT, dstVT)
}
err := Deserialize(obj, dst.Addr().Interface())
if err != nil {
return err
}
case srcT.Kind() == reflect.Slice:
if dstVT.Kind() != reflect.Slice {
return E.TypeError("slice", srcT, dstVT)
}
newSlice := reflect.MakeSlice(dstVT, 0, src.Len())
i := 0
for _, v := range src.Seq2() {
tmp := reflect.New(dstVT.Elem()).Elem()
err := Convert(v, tmp)
if err != nil {
return err.Subjectf("[%d]", i)
}
newSlice = reflect.Append(newSlice, tmp)
i++
}
dst.Set(newSlice)
default:
// check if Convertor is implemented
if converter, ok := dst.Interface().(Convertor); ok {
converted, err := converter.ConvertFrom(src.Interface())
if err != nil {
return err
}
dst.Set(reflect.ValueOf(converted))
return nil
}
return E.TypeError("conversion", srcT, dstVT)
}
return nil
@ -197,7 +250,7 @@ func Deserialize(src SerializedObject, target any) E.NestedError {
func DeserializeJson(j map[string]string, target any) E.NestedError {
data, err := E.Check(json.Marshal(j))
if err.HasError() {
if err != nil {
return err
}
return E.From(json.Unmarshal(data, target))
@ -206,5 +259,3 @@ func DeserializeJson(j map[string]string, target any) E.NestedError {
func ToLowerNoSnake(s string) string {
return strings.ToLower(strings.ReplaceAll(s, "_", ""))
}
type SerializedObject = map[string]any

View file

@ -0,0 +1,47 @@
package utils
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
type S = struct {
I int
S string
IS []int
SS []string
MSI map[string]int
MIS map[int]string
}
var testStruct = S{
I: 1,
S: "hello",
IS: []int{1, 2, 3},
SS: []string{"a", "b", "c"},
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
}
var testStructSerialized = map[string]any{
"I": 1,
"S": "hello",
"IS": []int{1, 2, 3},
"SS": []string{"a", "b", "c"},
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
}
func TestSerialize(t *testing.T) {
s, err := Serialize(testStruct)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, s, testStructSerialized)
}
func TestDeserialize(t *testing.T) {
var s S
err := Deserialize(testStructSerialized, &s)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, s, testStruct)
}