mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
implement middleware compose
This commit is contained in:
parent
f5a36f94bb
commit
44cfd65f6c
17 changed files with 392 additions and 152 deletions
|
@ -11,6 +11,8 @@
|
|||
|
||||
A lightweight, easy-to-use, and [performant](docs/benchmark_result.md) reverse proxy with a web UI.
|
||||
|
||||
_Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_
|
||||
|
||||
## Table of content
|
||||
|
||||
<!-- TOC -->
|
||||
|
|
|
@ -2,6 +2,7 @@ package error
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -24,7 +25,6 @@ func NewBuilder(format string, args ...any) Builder {
|
|||
func (b Builder) Add(err NestedError) Builder {
|
||||
if err != nil {
|
||||
b.Lock()
|
||||
// TODO: if err severity is higher than b.severity, update b.severity
|
||||
b.errors = append(b.errors, err)
|
||||
b.Unlock()
|
||||
}
|
||||
|
@ -49,6 +49,8 @@ func (b Builder) Addf(format string, args ...any) Builder {
|
|||
func (b Builder) Build() NestedError {
|
||||
if len(b.errors) == 0 {
|
||||
return nil
|
||||
} else if len(b.errors) == 1 && !strings.ContainsRune(b.message, ' ') {
|
||||
return b.errors[0].Subject(b.message)
|
||||
}
|
||||
return Join(b.message, b.errors...)
|
||||
}
|
||||
|
|
|
@ -166,6 +166,8 @@ func (ne NestedError) Subject(s any) NestedError {
|
|||
}
|
||||
if ne.subject == "" {
|
||||
ne.subject = subject
|
||||
} else if !strings.ContainsRune(subject, ' ') || strings.ContainsRune(ne.subject, '.') {
|
||||
ne.subject = fmt.Sprintf("%s.%s", subject, ne.subject)
|
||||
} else {
|
||||
ne.subject = fmt.Sprintf("%s > %s", subject, ne.subject)
|
||||
}
|
||||
|
|
|
@ -6,15 +6,16 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrFailure = stderrors.New("failed")
|
||||
ErrInvalid = stderrors.New("invalid")
|
||||
ErrUnsupported = stderrors.New("unsupported")
|
||||
ErrUnexpected = stderrors.New("unexpected")
|
||||
ErrNotExists = stderrors.New("does not exist")
|
||||
ErrMissing = stderrors.New("missing")
|
||||
ErrDuplicated = stderrors.New("duplicated")
|
||||
ErrOutOfRange = stderrors.New("out of range")
|
||||
ErrTypeError = stderrors.New("type error")
|
||||
ErrFailure = stderrors.New("failed")
|
||||
ErrInvalid = stderrors.New("invalid")
|
||||
ErrUnsupported = stderrors.New("unsupported")
|
||||
ErrUnexpected = stderrors.New("unexpected")
|
||||
ErrNotExists = stderrors.New("does not exist")
|
||||
ErrMissing = stderrors.New("missing")
|
||||
ErrDuplicated = stderrors.New("duplicated")
|
||||
ErrOutOfRange = stderrors.New("out of range")
|
||||
ErrTypeError = stderrors.New("type error")
|
||||
ErrTypeMismatch = stderrors.New("type mismatch")
|
||||
)
|
||||
|
||||
const fmtSubjectWhat = "%w %v: %q"
|
||||
|
@ -63,6 +64,14 @@ func OutOfRange(subject any, value any) NestedError {
|
|||
return errorf("%v %w: %v", subject, ErrOutOfRange, value)
|
||||
}
|
||||
|
||||
func TypeError(subject any, from, to reflect.Value) NestedError {
|
||||
return errorf("%v %w: %T -> %T", subject, ErrTypeError, from.Interface(), to.Interface())
|
||||
func TypeError(subject any, from, to reflect.Type) NestedError {
|
||||
return errorf("%v %w: %s -> %s\n", subject, ErrTypeError, from, to)
|
||||
}
|
||||
|
||||
func TypeError2(subject any, from, to reflect.Value) NestedError {
|
||||
return TypeError(subject, from.Type(), to.Type())
|
||||
}
|
||||
|
||||
func TypeMismatch[Expect any](value any) NestedError {
|
||||
return errorf("%w: expect %s got %T", ErrTypeMismatch, reflect.TypeFor[Expect](), value)
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -53,7 +54,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
|
|||
return cri.m, nil
|
||||
}
|
||||
|
||||
func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
|
||||
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
|
||||
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
|
||||
return
|
||||
}
|
||||
|
@ -66,14 +67,14 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
|
|||
}
|
||||
|
||||
if common.IsTest {
|
||||
cfCIDRs = []*net.IPNet{
|
||||
cfCIDRs = []*types.CIDR{
|
||||
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
|
||||
}
|
||||
} else {
|
||||
cfCIDRs = make([]*net.IPNet, 0, 30)
|
||||
cfCIDRs = make([]*types.CIDR, 0, 30)
|
||||
err := errors.Join(
|
||||
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
|
||||
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
|
||||
|
@ -90,7 +91,7 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
|
|||
return
|
||||
}
|
||||
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
|
||||
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
|
||||
resp, err := http.Get(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -110,7 +111,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
|
||||
} else {
|
||||
cfCIDRs = append(cfCIDRs, cidr)
|
||||
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
|
@ -53,7 +54,14 @@ func (m *Middleware) String() string {
|
|||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
|
||||
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||
return json.MarshalIndent(map[string]any{
|
||||
"name": m.name,
|
||||
"options": m.impl,
|
||||
}, "", " ")
|
||||
}
|
||||
|
||||
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
if len(optsRaw) != 0 && m.withOptions != nil {
|
||||
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
|
||||
return nil, err
|
||||
|
@ -87,7 +95,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
|
|||
continue
|
||||
}
|
||||
|
||||
m, err := m.WithOptionsClone(opts, rp)
|
||||
m, err := m.WithOptionsClone(opts)
|
||||
if err != nil {
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
|
|
|
@ -8,24 +8,27 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middleware, outErr E.NestedError) {
|
||||
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, E.FailWith("read middleware compose file", err)
|
||||
}
|
||||
return BuildMiddlewaresFromYAML(fileContent)
|
||||
}
|
||||
|
||||
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
|
||||
b := E.NewBuilder("middlewares compile errors")
|
||||
defer b.To(&outErr)
|
||||
|
||||
var data map[string][]map[string]any
|
||||
fileContent, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
b.Add(E.FailWith("read file", err))
|
||||
return
|
||||
}
|
||||
err = yaml.Unmarshal(fileContent, &data)
|
||||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
b.Add(E.FailWith("toml unmarshal", err))
|
||||
return
|
||||
}
|
||||
middlewares = make(map[string]*Middleware)
|
||||
for name, defs := range data {
|
||||
chainErr := E.NewBuilder("errors in middleware chain %s", name)
|
||||
for name, defs := range rawMap {
|
||||
chainErr := E.NewBuilder(name)
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"].(string) == "" {
|
||||
|
@ -39,9 +42,9 @@ func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middlewa
|
|||
continue
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.withOptions(def)
|
||||
m, err := base.WithOptionsClone(def)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("%s.%d", name, i))
|
||||
chainErr.Add(err.Subjectf("item%d", i))
|
||||
continue
|
||||
}
|
||||
chain = append(chain, m)
|
||||
|
|
|
@ -1,9 +1,21 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
//go:embed test_data/middleware_compose.yml
|
||||
var testMiddlewareCompose []byte
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
// middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
|
||||
// ExpectNoError(t, err.Error())
|
||||
data, err := E.Check(json.MarshalIndent(middlewares, "", " "))
|
||||
ExpectNoError(t, err.Error())
|
||||
t.Log(string(data))
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package middleware
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
|
@ -15,27 +16,27 @@ import (
|
|||
var middlewares map[string]*Middleware
|
||||
|
||||
func Get(name string) (middleware *Middleware, ok bool) {
|
||||
middleware, ok = middlewares[name]
|
||||
middleware, ok = middlewares[strings.ToLower(name)]
|
||||
return
|
||||
}
|
||||
|
||||
// initialize middleware names and label parsers
|
||||
func init() {
|
||||
middlewares = map[string]*Middleware{
|
||||
"set_x_forwarded": SetXForwarded,
|
||||
"hide_x_forwarded": HideXForwarded,
|
||||
"redirect_http": RedirectHTTP,
|
||||
"forward_auth": ForwardAuth.m,
|
||||
"modify_response": ModifyResponse.m,
|
||||
"modify_request": ModifyRequest.m,
|
||||
"error_page": CustomErrorPage,
|
||||
"custom_error_page": CustomErrorPage,
|
||||
"real_ip": RealIP.m,
|
||||
"cloudflare_real_ip": CloudflareRealIP.m,
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
"redirecthttp": RedirectHTTP,
|
||||
"forwardauth": ForwardAuth.m,
|
||||
"modifyresponse": ModifyResponse.m,
|
||||
"modifyrequest": ModifyRequest.m,
|
||||
"errorpage": CustomErrorPage,
|
||||
"customerrorpage": CustomErrorPage,
|
||||
"realip": RealIP.m,
|
||||
"cloudflarerealip": CloudflareRealIP.m,
|
||||
}
|
||||
names := make(map[*Middleware][]string)
|
||||
for name, m := range middlewares {
|
||||
names[m] = append(names[m], name)
|
||||
names[m] = append(names[m], http.CanonicalHeaderKey(name))
|
||||
// register middleware name to docker label parsr
|
||||
// in order to parse middleware_name.option=value into correct type
|
||||
if m.labelParserMap != nil {
|
||||
|
@ -49,6 +50,7 @@ func init() {
|
|||
m.name = names[0]
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: seperate from init()
|
||||
b := E.NewBuilder("failed to load middlewares")
|
||||
middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0)
|
||||
|
@ -57,7 +59,7 @@ func init() {
|
|||
return
|
||||
}
|
||||
for _, defFile := range middlewareDefs {
|
||||
mws, err := BuildMiddlewaresFromYAML(defFile)
|
||||
mws, err := BuildMiddlewaresFromComposeFile(defFile)
|
||||
for name, m := range mws {
|
||||
if _, ok := middlewares[name]; ok {
|
||||
b.Add(E.Duplicated("middleware", name))
|
||||
|
|
|
@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
|
||||
mr, err := ModifyRequest.m.WithOptionsClone(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
|
|
|
@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
|
||||
mr, err := ModifyResponse.m.WithOptionsClone(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
|
||||
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))
|
||||
|
|
|
@ -2,11 +2,11 @@ package middleware
|
|||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
)
|
||||
|
||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||
|
@ -20,7 +20,7 @@ type realIPOpts struct {
|
|||
// Header is the name of the header to use for the real client IP
|
||||
Header string
|
||||
// From is a list of Address / CIDRs to trust
|
||||
From []*net.IPNet
|
||||
From []*types.CIDR
|
||||
/*
|
||||
If recursive search is disabled,
|
||||
the original client address that matches one of the trusted addresses is replaced by
|
||||
|
@ -35,7 +35,7 @@ type realIPOpts struct {
|
|||
var RealIP = &realIP{
|
||||
m: &Middleware{
|
||||
labelParserMap: D.ValueParserMap{
|
||||
"from": CIDRListParser,
|
||||
"from": D.YamlStringListParser,
|
||||
"recursive": D.BoolParser,
|
||||
},
|
||||
withOptions: NewRealIP,
|
||||
|
@ -45,14 +45,7 @@ var RealIP = &realIP{
|
|||
var realIPOptsDefault = func() *realIPOpts {
|
||||
return &realIPOpts{
|
||||
Header: "X-Real-IP",
|
||||
From: []*net.IPNet{
|
||||
{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)},
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
|
||||
{IP: net.ParseIP("fc00::"), Mask: net.CIDRMask(7, 128)},
|
||||
{IP: net.ParseIP("fe80::"), Mask: net.CIDRMask(10, 128)},
|
||||
},
|
||||
From: []*types.CIDR{},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,31 +65,6 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
|
|||
return riWithOpts.m, nil
|
||||
}
|
||||
|
||||
func CIDRListParser(s string) (any, E.NestedError) {
|
||||
sl, err := D.YamlStringListParser(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := E.NewBuilder("invalid CIDR(s)")
|
||||
|
||||
CIDRs := sl.([]string)
|
||||
res := make([]*net.IPNet, 0, len(CIDRs))
|
||||
|
||||
for _, cidr := range CIDRs {
|
||||
if !strings.Contains(cidr, "/") {
|
||||
cidr += "/32" // single IP
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
b.Add(E.Invalid("CIDR", cidr))
|
||||
continue
|
||||
}
|
||||
res = append(res, ipnet)
|
||||
}
|
||||
return res, b.Build()
|
||||
}
|
||||
|
||||
func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||
for _, CIDR := range ri.From {
|
||||
if CIDR.Contains(ip) {
|
||||
|
|
58
internal/net/http/middleware/real_ip_test.go
Normal file
58
internal/net/http/middleware/real_ip_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"header": "X-Real-IP",
|
||||
"from": []string{
|
||||
"127.0.0.0/8",
|
||||
"192.168.0.0/16",
|
||||
"172.16.0.0/12",
|
||||
},
|
||||
"recursive": true,
|
||||
}
|
||||
optExpected := &realIPOpts{
|
||||
Header: "X-Real-IP",
|
||||
From: []*types.CIDR{
|
||||
{
|
||||
IP: net.ParseIP("127.0.0.0"),
|
||||
Mask: net.IPv4Mask(255, 0, 0, 0),
|
||||
},
|
||||
{
|
||||
IP: net.ParseIP("192.168.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||
},
|
||||
{
|
||||
IP: net.ParseIP("172.16.0.0"),
|
||||
Mask: net.IPv4Mask(255, 240, 0, 0),
|
||||
},
|
||||
},
|
||||
Recursive: true,
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
ri, err := RealIP.m.WithOptionsClone(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
|
||||
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
|
||||
})
|
||||
|
||||
// t.Run("request_headers", func(t *testing.T) {
|
||||
// result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
|
||||
// middlewareOpt: opts,
|
||||
// })
|
||||
// ExpectNoError(t, err.Error())
|
||||
// ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
|
||||
// ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
|
||||
// ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
|
||||
// })
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
theGreatPretender:
|
||||
- use: HideXForwarded
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 6.6.6.6
|
||||
- use: ModifyResponse
|
||||
hideHeaders:
|
||||
- X-Test3
|
||||
- X-Test4
|
||||
|
||||
notAuthenticAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
addAuthCookiesToResponse:
|
||||
- session_id
|
||||
- user_id
|
||||
authResponseHeaders:
|
||||
- X-Auth-SessionID
|
||||
- X-Auth-UserID
|
||||
- use: CustomErrorPage
|
||||
|
||||
realIPAuthentik:
|
||||
- use: RedirectHTTP
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- "127.0.0.0/8"
|
||||
- "192.168.0.0/16"
|
||||
- "172.16.0.0/12"
|
||||
recursive: true
|
||||
- use: ForwardAuth
|
||||
address: https://authentik.company
|
||||
trustForwardHeader: true
|
||||
|
||||
testFakeRealIP:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
CF-Connecting-IP: 127.0.0.1
|
||||
- use: CloudflareRealIP
|
34
internal/types/cidr.go
Normal file
34
internal/types/cidr.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type CIDR net.IPNet
|
||||
|
||||
func (*CIDR) ConvertFrom(val any) (any, E.NestedError) {
|
||||
cidr, ok := val.(string)
|
||||
if !ok {
|
||||
return nil, E.TypeMismatch[string](val)
|
||||
}
|
||||
|
||||
if !strings.Contains(cidr, "/") {
|
||||
cidr += "/32" // single IP
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, E.Invalid("CIDR", cidr)
|
||||
}
|
||||
return (*CIDR)(ipnet), nil
|
||||
}
|
||||
|
||||
func (cidr *CIDR) Contains(ip net.IP) bool {
|
||||
return (*net.IPNet)(cidr).Contains(ip)
|
||||
}
|
||||
|
||||
func (cidr *CIDR) String() string {
|
||||
return (*net.IPNet)(cidr).String()
|
||||
}
|
|
@ -12,6 +12,11 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type SerializedObject = map[string]any
|
||||
type Convertor interface {
|
||||
ConvertFrom(value any) (any, E.NestedError)
|
||||
}
|
||||
|
||||
func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError {
|
||||
var i any
|
||||
|
||||
|
@ -89,7 +94,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
|
|||
} else if field.Anonymous {
|
||||
// If the field is an embedded struct, add its fields to the result
|
||||
fieldMap, err := Serialize(value.Field(i).Interface())
|
||||
if err.HasError() {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range fieldMap {
|
||||
|
@ -106,90 +111,138 @@ func Serialize(data any) (SerializedObject, E.NestedError) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func Deserialize(src SerializedObject, target any) E.NestedError {
|
||||
if src == nil || target == nil {
|
||||
// Deserialize takes a SerializedObject and a target value, and assigns the values in the SerializedObject to the target value.
|
||||
// Deserialize ignores case differences between the field names in the SerializedObject and the target.
|
||||
//
|
||||
// The target value must be a struct or a map[string]any.
|
||||
// If the target value is a struct, the SerializedObject will be deserialized into the struct fields.
|
||||
// If the target value is a map[string]any, the SerializedObject will be deserialized into the map.
|
||||
//
|
||||
// The function returns an error if the target value is not a struct or a map[string]any, or if there is an error during deserialization.
|
||||
func Deserialize(src SerializedObject, dst any) E.NestedError {
|
||||
if src == nil || dst == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tValue := reflect.ValueOf(target)
|
||||
mapping := make(map[string]string)
|
||||
dstV := reflect.ValueOf(dst)
|
||||
dstT := dstV.Type()
|
||||
|
||||
if tValue.Kind() == reflect.Ptr {
|
||||
tValue = tValue.Elem()
|
||||
if dstV.Kind() == reflect.Ptr {
|
||||
dstV = dstV.Elem()
|
||||
dstT = dstV.Type()
|
||||
}
|
||||
|
||||
// convert data fields to lower no-snake
|
||||
// convert target fields to lower no-snake
|
||||
// then check if the field of data is in the target
|
||||
|
||||
if tValue.Kind() == reflect.Struct {
|
||||
t := reflect.TypeOf(target).Elem()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
snakeCaseField := ToLowerNoSnake(field.Name)
|
||||
mapping[snakeCaseField] = field.Name
|
||||
// TODO: use E.Builder to collect errors from all fields
|
||||
|
||||
if dstV.Kind() == reflect.Struct {
|
||||
mapping := make(map[string]reflect.Value)
|
||||
for i := 0; i < dstV.NumField(); i++ {
|
||||
field := dstT.Field(i)
|
||||
mapping[ToLowerNoSnake(field.Name)] = dstV.Field(i)
|
||||
}
|
||||
} else if tValue.Kind() == reflect.Map && tValue.Type().Key().Kind() == reflect.String {
|
||||
if tValue.IsNil() {
|
||||
tValue.Set(reflect.MakeMap(tValue.Type()))
|
||||
for k, v := range src {
|
||||
if field, ok := mapping[ToLowerNoSnake(k)]; ok {
|
||||
err := Convert(reflect.ValueOf(v), field)
|
||||
if err != nil {
|
||||
return err.Subject(k)
|
||||
}
|
||||
} else {
|
||||
return E.Unexpected("field", k)
|
||||
}
|
||||
}
|
||||
} else if dstV.Kind() == reflect.Map && dstT.Key().Kind() == reflect.String {
|
||||
if dstV.IsNil() {
|
||||
dstV.Set(reflect.MakeMap(dstT))
|
||||
}
|
||||
for k := range src {
|
||||
// TODO: type check
|
||||
tValue.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), reflect.ValueOf(src[k]))
|
||||
tmp := reflect.New(dstT.Elem()).Elem()
|
||||
err := Convert(reflect.ValueOf(src[k]), tmp)
|
||||
if err != nil {
|
||||
return err.Subject(k)
|
||||
}
|
||||
dstV.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), tmp)
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return E.Unsupported("target type", fmt.Sprintf("%T", target))
|
||||
return E.Unsupported("target type", fmt.Sprintf("%T", dst))
|
||||
}
|
||||
|
||||
for k, v := range src {
|
||||
kCleaned := ToLowerNoSnake(k)
|
||||
if fieldName, ok := mapping[kCleaned]; ok {
|
||||
prop := tValue.FieldByName(fieldName)
|
||||
propType := prop.Type()
|
||||
isPtr := prop.Kind() == reflect.Ptr
|
||||
if prop.CanSet() {
|
||||
val := reflect.ValueOf(v)
|
||||
vType := val.Type()
|
||||
switch {
|
||||
case isPtr && vType.ConvertibleTo(propType.Elem()):
|
||||
ptr := reflect.New(propType.Elem())
|
||||
ptr.Elem().Set(val.Convert(propType.Elem()))
|
||||
prop.Set(ptr)
|
||||
case vType.ConvertibleTo(propType):
|
||||
prop.Set(val.Convert(propType))
|
||||
case isPtr:
|
||||
var vSerialized SerializedObject
|
||||
vSerialized, ok = v.(SerializedObject)
|
||||
if !ok {
|
||||
if vType.ConvertibleTo(reflect.TypeFor[SerializedObject]()) {
|
||||
vSerialized = val.Convert(reflect.TypeFor[SerializedObject]()).Interface().(SerializedObject)
|
||||
} else {
|
||||
return E.Failure(fmt.Sprintf("convert %s (%T) to %s", k, v, reflect.TypeFor[SerializedObject]()))
|
||||
}
|
||||
}
|
||||
propNew := reflect.New(propType.Elem())
|
||||
err := Deserialize(vSerialized, propNew.Interface())
|
||||
if err.HasError() {
|
||||
return E.Failure("set field").With(err).Subject(k)
|
||||
}
|
||||
prop.Set(propNew)
|
||||
default:
|
||||
obj, ok := val.Interface().(SerializedObject)
|
||||
if !ok {
|
||||
return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType)
|
||||
}
|
||||
err := Deserialize(obj, prop.Addr().Interface())
|
||||
if err.HasError() {
|
||||
return E.Failure("set field").With(err).Subject(k)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return E.Unsupported("field", k).Extraf("type %s is not settable", propType)
|
||||
}
|
||||
} else {
|
||||
return E.Unexpected("field", k)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert attempts to convert the src to dst.
|
||||
//
|
||||
// If src is a map, it is deserialized into dst.
|
||||
// If src is a slice, each of its elements are converted and stored in dst.
|
||||
// For any other type, it is converted using the reflect.Value.Convert function (if possible).
|
||||
//
|
||||
// If dst is not settable, an error is returned.
|
||||
// If src cannot be converted to dst, an error is returned.
|
||||
// If any error occurs during conversion (e.g. deserialization), it is returned.
|
||||
//
|
||||
// Returns:
|
||||
// - error: the error occurred during conversion, or nil if no error occurred.
|
||||
func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
|
||||
srcT := src.Type()
|
||||
dstVT := dst.Type()
|
||||
|
||||
if src.Kind() == reflect.Interface {
|
||||
src = src.Elem()
|
||||
srcT = src.Type()
|
||||
}
|
||||
|
||||
if !dst.CanSet() {
|
||||
return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface()))
|
||||
}
|
||||
|
||||
switch {
|
||||
case srcT.AssignableTo(dstVT):
|
||||
dst.Set(src)
|
||||
case srcT.ConvertibleTo(dstVT):
|
||||
dst.Set(src.Convert(dstVT))
|
||||
case srcT.Kind() == reflect.Map:
|
||||
if dstVT.Kind() != reflect.Map {
|
||||
return E.TypeError("map", srcT, dstVT)
|
||||
}
|
||||
obj, ok := src.Interface().(SerializedObject)
|
||||
if !ok {
|
||||
return E.TypeError("map", srcT, dstVT)
|
||||
}
|
||||
err := Deserialize(obj, dst.Addr().Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case srcT.Kind() == reflect.Slice:
|
||||
if dstVT.Kind() != reflect.Slice {
|
||||
return E.TypeError("slice", srcT, dstVT)
|
||||
}
|
||||
newSlice := reflect.MakeSlice(dstVT, 0, src.Len())
|
||||
i := 0
|
||||
for _, v := range src.Seq2() {
|
||||
tmp := reflect.New(dstVT.Elem()).Elem()
|
||||
err := Convert(v, tmp)
|
||||
if err != nil {
|
||||
return err.Subjectf("[%d]", i)
|
||||
}
|
||||
newSlice = reflect.Append(newSlice, tmp)
|
||||
i++
|
||||
}
|
||||
dst.Set(newSlice)
|
||||
default:
|
||||
// check if Convertor is implemented
|
||||
if converter, ok := dst.Interface().(Convertor); ok {
|
||||
converted, err := converter.ConvertFrom(src.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Set(reflect.ValueOf(converted))
|
||||
return nil
|
||||
}
|
||||
return E.TypeError("conversion", srcT, dstVT)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -197,7 +250,7 @@ func Deserialize(src SerializedObject, target any) E.NestedError {
|
|||
|
||||
func DeserializeJson(j map[string]string, target any) E.NestedError {
|
||||
data, err := E.Check(json.Marshal(j))
|
||||
if err.HasError() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return E.From(json.Unmarshal(data, target))
|
||||
|
@ -206,5 +259,3 @@ func DeserializeJson(j map[string]string, target any) E.NestedError {
|
|||
func ToLowerNoSnake(s string) string {
|
||||
return strings.ToLower(strings.ReplaceAll(s, "_", ""))
|
||||
}
|
||||
|
||||
type SerializedObject = map[string]any
|
||||
|
|
47
internal/utils/serialization_test.go
Normal file
47
internal/utils/serialization_test.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
type S = struct {
|
||||
I int
|
||||
S string
|
||||
IS []int
|
||||
SS []string
|
||||
MSI map[string]int
|
||||
MIS map[int]string
|
||||
}
|
||||
|
||||
var testStruct = S{
|
||||
I: 1,
|
||||
S: "hello",
|
||||
IS: []int{1, 2, 3},
|
||||
SS: []string{"a", "b", "c"},
|
||||
MSI: map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
MIS: map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
|
||||
var testStructSerialized = map[string]any{
|
||||
"I": 1,
|
||||
"S": "hello",
|
||||
"IS": []int{1, 2, 3},
|
||||
"SS": []string{"a", "b", "c"},
|
||||
"MSI": map[string]int{"a": 1, "b": 2, "c": 3},
|
||||
"MIS": map[int]string{1: "a", 2: "b", 3: "c"},
|
||||
}
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
s, err := Serialize(testStruct)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectDeepEqual(t, s, testStructSerialized)
|
||||
}
|
||||
|
||||
func TestDeserialize(t *testing.T) {
|
||||
var s S
|
||||
err := Deserialize(testStructSerialized, &s)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectDeepEqual(t, s, testStruct)
|
||||
}
|
Loading…
Add table
Reference in a new issue