implement middleware compose

This commit is contained in:
yusing 2024-10-01 16:38:07 +08:00
parent f5a36f94bb
commit 44cfd65f6c
17 changed files with 392 additions and 152 deletions

View file

@ -11,6 +11,8 @@
A lightweight, easy-to-use, and [performant](docs/benchmark_result.md) reverse proxy with a web UI. A lightweight, easy-to-use, and [performant](docs/benchmark_result.md) reverse proxy with a web UI.
_Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
## Table of content ## Table of content
<!-- TOC --> <!-- TOC -->

View file

@ -2,6 +2,7 @@ package error
import ( import (
"fmt" "fmt"
"strings"
"sync" "sync"
) )
@ -24,7 +25,6 @@ func NewBuilder(format string, args ...any) Builder {
func (b Builder) Add(err NestedError) Builder { func (b Builder) Add(err NestedError) Builder {
if err != nil { if err != nil {
b.Lock() b.Lock()
// TODO: if err severity is higher than b.severity, update b.severity
b.errors = append(b.errors, err) b.errors = append(b.errors, err)
b.Unlock() b.Unlock()
} }
@ -49,6 +49,8 @@ func (b Builder) Addf(format string, args ...any) Builder {
func (b Builder) Build() NestedError { func (b Builder) Build() NestedError {
if len(b.errors) == 0 { if len(b.errors) == 0 {
return nil return nil
} else if len(b.errors) == 1 && !strings.ContainsRune(b.message, ' ') {
return b.errors[0].Subject(b.message)
} }
return Join(b.message, b.errors...) return Join(b.message, b.errors...)
} }

View file

@ -166,6 +166,8 @@ func (ne NestedError) Subject(s any) NestedError {
} }
if ne.subject == "" { if ne.subject == "" {
ne.subject = subject ne.subject = subject
} else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') {
ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject)
} else { } else {
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject) ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
} }

View file

@ -15,6 +15,7 @@ var (
ErrDuplicated = stderrors.New("duplicated") ErrDuplicated = stderrors.New("duplicated")
ErrOutOfRange = stderrors.New("out of range") ErrOutOfRange = stderrors.New("out of range")
ErrTypeError = stderrors.New("type error") ErrTypeError = stderrors.New("type error")
ErrTypeMismatch = stderrors.New("type mismatch")
) )
const fmtSubjectWhat = "%w %v: %q" const fmtSubjectWhat = "%w %v: %q"
@ -63,6 +64,14 @@ func OutOfRange(subject any, value any) NestedError {
return errorf("%v %w: %v", subject, ErrOutOfRange, value) return errorf("%v %w: %v", subject, ErrOutOfRange, value)
} }
func TypeError(subject any, from, to reflect.Value) NestedError { func TypeError(subject any, from, to reflect.Type) NestedError {
return errorf("%v %w: %T -> %T", subject, ErrTypeError, from.Interface(), to.Interface()) return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
}
func TypeError2(subject any, from, to reflect.Value) NestedError {
return TypeError(subject, from.Type(), to.Type())
}
func TypeMismatch[Expect any](value any) NestedError {
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
) )
const ( const (
@ -53,7 +54,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
return cri.m, nil return cri.m, nil
} }
func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) { func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval { if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return return
} }
@ -66,14 +67,14 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
} }
if common.IsTest { if common.IsTest {
cfCIDRs = []*net.IPNet{ cfCIDRs = []*types.CIDR{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)}, {IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)}, {IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)}, {IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)}, {IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
} }
} else { } else {
cfCIDRs = make([]*net.IPNet, 0, 30) cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join( err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs), fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs), fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
@ -90,7 +91,7 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
return return
} }
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error { func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
resp, err := http.Get(endpoint) resp, err := http.Get(endpoint)
if err != nil { if err != nil {
return err return err
@ -110,7 +111,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
if err != nil { if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line) return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} else { } else {
cfCIDRs = append(cfCIDRs, cidr) cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
} }
} }

View file

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"encoding/json"
"net/http" "net/http"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
@ -53,7 +54,14 @@ func (m *Middleware) String() string {
return m.name return m.name
} }
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) { func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil { if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw); err != nil { if mWithOpt, err := m.withOptions(optsRaw); err != nil {
return nil, err return nil, err
@ -87,7 +95,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
continue continue
} }
m, err := m.WithOptionsClone(opts, rp) m, err := m.WithOptionsClone(opts)
if err != nil { if err != nil {
invalidOpts.Add(err.Subject(name)) invalidOpts.Add(err.Subject(name))
continue continue

View file

@ -8,24 +8,27 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middleware, outErr E.NestedError) { func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, E.FailWith("read middleware compose file", err)
}
return BuildMiddlewaresFromYAML(fileContent)
}
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
b := E.NewBuilder("middlewares compile errors") b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr) defer b.To(&outErr)
var data map[string][]map[string]any var rawMap map[string][]map[string]any
fileContent, err := os.ReadFile(filePath) err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("read file", err))
return
}
err = yaml.Unmarshal(fileContent, &data)
if err != nil { if err != nil {
b.Add(E.FailWith("toml unmarshal", err)) b.Add(E.FailWith("toml unmarshal", err))
return return
} }
middlewares = make(map[string]*Middleware) middlewares = make(map[string]*Middleware)
for name, defs := range data { for name, defs := range rawMap {
chainErr := E.NewBuilder("errors in middleware chain %s", name) chainErr := E.NewBuilder(name)
chain := make([]*Middleware, 0, len(defs)) chain := make([]*Middleware, 0, len(defs))
for i, def := range defs { for i, def := range defs {
if def["use"] == nil || def["use"].(string) == "" { if def["use"] == nil || def["use"].(string) == "" {
@ -39,9 +42,9 @@ func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middlewa
continue continue
} }
delete(def, "use") delete(def, "use")
m, err := base.withOptions(def) m, err := base.WithOptionsClone(def)
if err != nil { if err != nil {
chainErr.Add(err.Subjectf("%s.%d", name, i)) chainErr.Add(err.Subjectf("item%d", i))
continue continue
} }
chain = append(chain, m) chain = append(chain, m)

View file

@ -1,9 +1,21 @@
package middleware package middleware
import ( import (
_ "embed"
"encoding/json"
"testing" "testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
) )
func TestBuild(t *testing.T) { //go:embed test_data/middleware_compose.yml
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
// middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
// ExpectNoError(t, err.Error())
data, err := E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
t.Log(string(data))
} }

View file

@ -2,6 +2,7 @@ package middleware
import ( import (
"fmt" "fmt"
"net/http"
"path" "path"
"strings" "strings"
@ -15,27 +16,27 @@ import (
var middlewares map[string]*Middleware var middlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) { func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name] middleware, ok = middlewares[strings.ToLower(name)]
return return
} }
// initialize middleware names and label parsers // initialize middleware names and label parsers
func init() { func init() {
middlewares = map[string]*Middleware{ middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded, "setxforwarded": SetXForwarded,
"hide_x_forwarded": HideXForwarded, "hidexforwarded": HideXForwarded,
"redirect_http": RedirectHTTP, "redirecthttp": RedirectHTTP,
"forward_auth": ForwardAuth.m, "forwardauth": ForwardAuth.m,
"modify_response": ModifyResponse.m, "modifyresponse": ModifyResponse.m,
"modify_request": ModifyRequest.m, "modifyrequest": ModifyRequest.m,
"error_page": CustomErrorPage, "errorpage": CustomErrorPage,
"custom_error_page": CustomErrorPage, "customerrorpage": CustomErrorPage,
"real_ip": RealIP.m, "realip": RealIP.m,
"cloudflare_real_ip": CloudflareRealIP.m, "cloudflarerealip": CloudflareRealIP.m,
} }
names := make(map[*Middleware][]string) names := make(map[*Middleware][]string)
for name, m := range middlewares { for name, m := range middlewares {
names[m] = append(names[m], name) names[m] = append(names[m], http.CanonicalHeaderKey(name))
// register middleware name to docker label parsr // register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type // in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil { if m.labelParserMap != nil {
@ -49,6 +50,7 @@ func init() {
m.name = names[0] m.name = names[0]
} }
} }
// TODO: seperate from init() // TODO: seperate from init()
b := E.NewBuilder("failed to load middlewares") b := E.NewBuilder("failed to load middlewares")
middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0) middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0)
@ -57,7 +59,7 @@ func init() {
return return
} }
for _, defFile := range middlewareDefs { for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromYAML(defFile) mws, err := BuildMiddlewaresFromComposeFile(defFile)
for name, m := range mws { for name, m := range mws {
if _, ok := middlewares[name]; ok { if _, ok := middlewares[name]; ok {
b.Add(E.Duplicated("middleware", name)) b.Add(E.Duplicated("middleware", name))

View file

@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) {
} }
t.Run("set_options", func(t *testing.T) { t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil) mr, err := ModifyRequest.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error()) ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) {
} }
t.Run("set_options", func(t *testing.T) { t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil) mr, err := ModifyResponse.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error()) ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string)) ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))

View file

@ -2,11 +2,11 @@ package middleware
import ( import (
"net" "net"
"strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker" D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
) )
// https://nginx.org/en/docs/http/ngx_http_realip_module.html // https://nginx.org/en/docs/http/ngx_http_realip_module.html
@ -20,7 +20,7 @@ type realIPOpts struct {
// Header is the name of the header to use for the real client IP // Header is the name of the header to use for the real client IP
Header string Header string
// From is a list of Address / CIDRs to trust // From is a list of Address / CIDRs to trust
From []*net.IPNet From []*types.CIDR
/* /*
If recursive search is disabled, If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by the original client address that matches one of the trusted addresses is replaced by
@ -35,7 +35,7 @@ type realIPOpts struct {
var RealIP = &realIP{ var RealIP = &realIP{
m: &Middleware{ m: &Middleware{
labelParserMap: D.ValueParserMap{ labelParserMap: D.ValueParserMap{
"from": CIDRListParser, "from": D.YamlStringListParser,
"recursive": D.BoolParser, "recursive": D.BoolParser,
}, },
withOptions: NewRealIP, withOptions: NewRealIP,
@ -45,14 +45,7 @@ var RealIP = &realIP{
var realIPOptsDefault = func() *realIPOpts { var realIPOptsDefault = func() *realIPOpts {
return &realIPOpts{ return &realIPOpts{
Header: "X-Real-IP", Header: "X-Real-IP",
From: []*net.IPNet{ From: []*types.CIDR{},
{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
{IP: net.ParseIP("fc00::"), Mask: net.CIDRMask(7, 128)},
{IP: net.ParseIP("fe80::"), Mask: net.CIDRMask(10, 128)},
},
} }
} }
@ -72,31 +65,6 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
return riWithOpts.m, nil return riWithOpts.m, nil
} }
func CIDRListParser(s string) (any, E.NestedError) {
sl, err := D.YamlStringListParser(s)
if err != nil {
return nil, err
}
b := E.NewBuilder("invalid CIDR(s)")
CIDRs := sl.([]string)
res := make([]*net.IPNet, 0, len(CIDRs))
for _, cidr := range CIDRs {
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
b.Add(E.Invalid("CIDR", cidr))
continue
}
res = append(res, ipnet)
}
return res, b.Build()
}
func (ri *realIP) isInCIDRList(ip net.IP) bool { func (ri *realIP) isInCIDRList(ip net.IP) bool {
for _, CIDR := range ri.From { for _, CIDR := range ri.From {
if CIDR.Contains(ip) { if CIDR.Contains(ip) {

View file

@ -0,0 +1,58 @@
package middleware
import (
"net"
"testing"
"github.com/yusing/go-proxy/internal/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIP(t *testing.T) {
opts := OptionsRaw{
"header": "X-Real-IP",
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
"172.16.0.0/12",
},
"recursive": true,
}
optExpected := &realIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
Mask: net.IPv4Mask(255, 0, 0, 0),
},
{
IP: net.ParseIP("192.168.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
{
IP: net.ParseIP("172.16.0.0"),
Mask: net.IPv4Mask(255, 240, 0, 0),
},
},
Recursive: true,
}
t.Run("set_options", func(t *testing.T) {
ri, err := RealIP.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
})
// t.Run("request_headers", func(t *testing.T) {
// result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
// middlewareOpt: opts,
// })
// ExpectNoError(t, err.Error())
// ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
// ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
// ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
// })
}

View file

@ -0,0 +1,41 @@
theGreatPretender:
- use: HideXForwarded
- use: ModifyRequest
setHeaders:
X-Real-IP: 6.6.6.6
- use: ModifyResponse
hideHeaders:
- X-Test3
- X-Test4
notAuthenticAuthentik:
- use: RedirectHTTP
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
addAuthCookiesToResponse:
- session_id
- user_id
authResponseHeaders:
- X-Auth-SessionID
- X-Auth-UserID
- use: CustomErrorPage
realIPAuthentik:
- use: RedirectHTTP
- use: RealIP
header: X-Real-IP
from:
- "127.0.0.0/8"
- "192.168.0.0/16"
- "172.16.0.0/12"
recursive: true
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
testFakeRealIP:
- use: ModifyRequest
setHeaders:
CF-Connecting-IP: 127.0.0.1
- use: CloudflareRealIP

34
internal/types/cidr.go Normal file
View file

@ -0,0 +1,34 @@
package types
import (
"net"
"strings"
E "github.com/yusing/go-proxy/internal/error"
)
type CIDR net.IPNet
func (*CIDR) ConvertFrom(val any) (any, E.NestedError) {
cidr, ok := val.(string)
if !ok {
return nil, E.TypeMismatch[string](val)
}
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, E.Invalid("CIDR", cidr)
}
return (*CIDR)(ipnet), nil
}
func (cidr *CIDR) Contains(ip net.IP) bool {
return (*net.IPNet)(cidr).Contains(ip)
}
func (cidr *CIDR) String() string {
return (*net.IPNet)(cidr).String()
}

View file

@ -12,6 +12,11 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type SerializedObject = map[string]any
type Convertor interface {
ConvertFrom(value any) (any, E.NestedError)
}
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
var i any var i any
@ -89,7 +94,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
} else if field.Anonymous { } else if field.Anonymous {
// If the field is an embedded struct, add its fields to the result // If the field is an embedded struct, add its fields to the result
fieldMap, err := Serialize(value.Field(i).Interface()) fieldMap, err := Serialize(value.Field(i).Interface())
if err.HasError() { if err != nil {
return nil, err return nil, err
} }
for k, v := range fieldMap { for k, v := range fieldMap {
@ -106,98 +111,146 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
return result, nil return result, nil
} }
func Deserialize(src SerializedObject, target any) E.NestedError { // Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
if src == nil || target == nil { // Deserialize ignores case differences between the field names in the SerializedObject and the target.
//
// The target value must be a struct or a map[string]any.
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields.
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
//
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
func Deserialize(src SerializedObject, dst any) E.NestedError {
if src == nil || dst == nil {
return nil return nil
} }
tValue := reflect.ValueOf(target) dstV := reflect.ValueOf(dst)
mapping := make(map[string]string) dstT := dstV.Type()
if tValue.Kind() == reflect.Ptr { if dstV.Kind() == reflect.Ptr {
tValue = tValue.Elem() dstV = dstV.Elem()
dstT = dstV.Type()
} }
// convert data fields to lower no-snake // convert data fields to lower no-snake
// convert target fields to lower no-snake // convert target fields to lower no-snake
// then check if the field of data is in the target // then check if the field of data is in the target
if tValue.Kind() == reflect.Struct { // TODO: use E.Builder to collect errors from all fields
t := reflect.TypeOf(target).Elem()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
snakeCaseField := ToLowerNoSnake(field.Name)
mapping[snakeCaseField] = field.Name
}
} else if tValue.Kind() == reflect.Map && tValue.Type().Key().Kind() == reflect.String {
if tValue.IsNil() {
tValue.Set(reflect.MakeMap(tValue.Type()))
}
for k := range src {
// TODO: type check
tValue.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), reflect.ValueOf(src[k]))
}
return nil
} else {
return E.Unsupported("target type", fmt.Sprintf("%T", target))
}
if dstV.Kind() == reflect.Struct {
mapping := make(map[string]reflect.Value)
for i := 0; i < dstV.NumField(); i++ {
field := dstT.Field(i)
mapping[ToLowerNoSnake(field.Name)] = dstV.Field(i)
}
for k, v := range src { for k, v := range src {
kCleaned := ToLowerNoSnake(k) if field, ok := mapping[ToLowerNoSnake(k)]; ok {
if fieldName, ok := mapping[kCleaned]; ok { err := Convert(reflect.ValueOf(v), field)
prop := tValue.FieldByName(fieldName) if err != nil {
propType := prop.Type() return err.Subject(k)
isPtr := prop.Kind() == reflect.Ptr
if prop.CanSet() {
val := reflect.ValueOf(v)
vType := val.Type()
switch {
case isPtr && vType.ConvertibleTo(propType.Elem()):
ptr := reflect.New(propType.Elem())
ptr.Elem().Set(val.Convert(propType.Elem()))
prop.Set(ptr)
case vType.ConvertibleTo(propType):
prop.Set(val.Convert(propType))
case isPtr:
var vSerialized SerializedObject
vSerialized, ok = v.(SerializedObject)
if !ok {
if vType.ConvertibleTo(reflect.TypeFor[SerializedObject]()) {
vSerialized = val.Convert(reflect.TypeFor[SerializedObject]()).Interface().(SerializedObject)
} else {
return E.Failure(fmt.Sprintf("convert %s (%T) to %s", k, v, reflect.TypeFor[SerializedObject]()))
}
}
propNew := reflect.New(propType.Elem())
err := Deserialize(vSerialized, propNew.Interface())
if err.HasError() {
return E.Failure("set field").With(err).Subject(k)
}
prop.Set(propNew)
default:
obj, ok := val.Interface().(SerializedObject)
if !ok {
return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType)
}
err := Deserialize(obj, prop.Addr().Interface())
if err.HasError() {
return E.Failure("set field").With(err).Subject(k)
}
}
} else {
return E.Unsupported("field", k).Extraf("type %s is not settable", propType)
} }
} else { } else {
return E.Unexpected("field", k) return E.Unexpected("field", k)
} }
} }
} else if dstV.Kind() == reflect.Map && dstT.Key().Kind() == reflect.String {
if dstV.IsNil() {
dstV.Set(reflect.MakeMap(dstT))
}
for k := range src {
tmp := reflect.New(dstT.Elem()).Elem()
err := Convert(reflect.ValueOf(src[k]), tmp)
if err != nil {
return err.Subject(k)
}
dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp)
}
return nil
} else {
return E.Unsupported("target type", fmt.Sprintf("%T", dst))
}
return nil
}
// Convert attempts to convert the src to dst.
//
// If src is a map, it is deserialized into dst.
// If src is a slice, each of its elements are converted and stored in dst.
// For any other type, it is converted using the reflect.Value.Convert function (if possible).
//
// If dst is not settable, an error is returned.
// If src cannot be converted to dst, an error is returned.
// If any error occurs during conversion (e.g. deserialization), it is returned.
//
// Returns:
// - error: the error occurred during conversion, or nil if no error occurred.
func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
srcT := src.Type()
dstVT := dst.Type()
if src.Kind() == reflect.Interface {
src = src.Elem()
srcT = src.Type()
}
if !dst.CanSet() {
return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface()))
}
switch {
case srcT.AssignableTo(dstVT):
dst.Set(src)
case srcT.ConvertibleTo(dstVT):
dst.Set(src.Convert(dstVT))
case srcT.Kind() == reflect.Map:
if dstVT.Kind() != reflect.Map {
return E.TypeError("map", srcT, dstVT)
}
obj, ok := src.Interface().(SerializedObject)
if !ok {
return E.TypeError("map", srcT, dstVT)
}
err := Deserialize(obj, dst.Addr().Interface())
if err != nil {
return err
}
case srcT.Kind() == reflect.Slice:
if dstVT.Kind() != reflect.Slice {
return E.TypeError("slice", srcT, dstVT)
}
newSlice := reflect.MakeSlice(dstVT, 0, src.Len())
i := 0
for _, v := range src.Seq2() {
tmp := reflect.New(dstVT.Elem()).Elem()
err := Convert(v, tmp)
if err != nil {
return err.Subjectf("[%d]", i)
}
newSlice = reflect.Append(newSlice, tmp)
i++
}
dst.Set(newSlice)
default:
// check if Convertor is implemented
if converter, ok := dst.Interface().(Convertor); ok {
converted, err := converter.ConvertFrom(src.Interface())
if err != nil {
return err
}
dst.Set(reflect.ValueOf(converted))
return nil
}
return E.TypeError("conversion", srcT, dstVT)
}
return nil return nil
} }
func DeserializeJson(j map[string]string, target any) E.NestedError { func DeserializeJson(j map[string]string, target any) E.NestedError {
data, err := E.Check(json.Marshal(j)) data, err := E.Check(json.Marshal(j))
if err.HasError() { if err != nil {
return err return err
} }
return E.From(json.Unmarshal(data, target)) return E.From(json.Unmarshal(data, target))
@ -206,5 +259,3 @@ func DeserializeJson(j map[string]string, target any) E.NestedError {
func ToLowerNoSnake(s string) string { func ToLowerNoSnake(s string) string {
return strings.ToLower(strings.ReplaceAll(s, "_", "")) return strings.ToLower(strings.ReplaceAll(s, "_", ""))
} }
type SerializedObject = map[string]any

View file

@ -0,0 +1,47 @@
package utils
import (
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
type S = struct {
I int
S string
IS []int
SS []string
MSI map[string]int
MIS map[int]string
}
var testStruct = S{
I: 1,
S: "hello",
IS: []int{1, 2, 3},
SS: []string{"a", "b", "c"},
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
}
var testStructSerialized = map[string]any{
"I": 1,
"S": "hello",
"IS": []int{1, 2, 3},
"SS": []string{"a", "b", "c"},
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
}
func TestSerialize(t *testing.T) {
s, err := Serialize(testStruct)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, s, testStructSerialized)
}
func TestDeserialize(t *testing.T) {
var s S
err := Deserialize(testStructSerialized, &s)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, s, testStruct)
}