fixed serialization and middleware compose

This commit is contained in:
yusing 2024-10-02 01:04:34 +08:00
parent 1bac96dc2a
commit ed887a5cfc
8 changed files with 74 additions and 48 deletions

View file

@ -24,6 +24,7 @@ import (
"github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/server"
F "github.com/yusing/go-proxy/internal/utils/functional"
@ -80,8 +81,9 @@ func main() {
prepareDirectory(dir)
}
err := config.Load()
if err != nil {
middleware.LoadComposeFiles()
if err := config.Load(); err != nil {
logrus.Warn(err)
}
cfg := config.GetInstance()
@ -113,11 +115,6 @@ func main() {
}
cfg.StartProxyProviders()
if err.HasError() {
l.Warn(err)
}
cfg.WatchChanges()
onShutdown.Add(docker.CloseAllClients)
@ -132,7 +129,7 @@ func main() {
if autocert != nil {
ctx, cancel := context.WithCancel(context.Background())
if err = autocert.Setup(ctx); err != nil {
if err := autocert.Setup(ctx); err != nil {
l.Fatal(err)
} else {
onShutdown.Add(cancel)

View file

@ -20,6 +20,7 @@
- [Hide X-Forwarded-\*](#hide-x-forwarded-)
- [Set X-Forwarded-\*](#set-x-forwarded-)
- [Forward Authorization header (experimental)](#forward-authorization-header-experimental)
- [Middleware Compose](#middleware-compose)
- [Examples](#examples)
- [Authentik (untested, experimental)](#authentik-untested-experimental)
@ -356,6 +357,14 @@ http:
[🔼Back to top](#table-of-content)
## Middleware Compose
Middleware compose is a way to create reusable middlewares in file(s), just like docker compose.
You may use them with `<middleware_name>@file`
See [example](../internal/net/http/middleware/test_data/middleware_compose.yml)
## Examples
### Authentik (untested, experimental)

View file

@ -18,7 +18,7 @@ const (
ConfigExampleFileName = "config.example.yml"
ConfigPath = ConfigBasePath + "/" + ConfigFileName
MiddlewareDefsBasePath = ConfigBasePath + "/middlewares"
MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
)
const (
@ -41,6 +41,7 @@ var (
ConfigBasePath,
SchemaBasePath,
ErrorPagesBasePath,
MiddlewareComposeBasePath,
}
)

View file

@ -1,6 +1,7 @@
package docker
import (
"strconv"
"strings"
E "github.com/yusing/go-proxy/internal/error"
@ -76,3 +77,11 @@ func BoolParser(value string) (any, E.NestedError) {
return nil, E.Invalid("boolean value", value)
}
}
func IntParser(value string) (any, E.NestedError) {
i, err := strconv.Atoi(value)
if err != nil {
return 0, E.Invalid("integer value", value)
}
return i, nil
}

View file

@ -13,9 +13,10 @@ import (
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
// middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
// ExpectNoError(t, err.Error())
data, err := E.Check(json.MarshalIndent(middlewares, "", " "))
middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
ExpectNoError(t, err.Error())
t.Log(string(data))
_, err = E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
// t.Log(string(data))
// TODO: test
}

View file

@ -33,6 +33,7 @@ func init() {
"customerrorpage": CustomErrorPage,
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
"cidrwhitelist": CIDRWhiteList.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
@ -50,10 +51,11 @@ func init() {
m.name = names[0]
}
}
}
// TODO: seperate from init()
func LoadComposeFiles() {
b := E.NewBuilder("failed to load middlewares")
middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0)
middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil {
logrus.Errorf("failed to list middleware definitions: %s", err)
return

View file

@ -45,14 +45,5 @@ func TestSetRealIP(t *testing.T) {
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
})
// t.Run("request_headers", func(t *testing.T) {
// result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
// middlewareOpt: opts,
// })
// ExpectNoError(t, err.Error())
// ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
// ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
// ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
// })
// TODO test
}

View file

@ -13,7 +13,7 @@ import (
)
type SerializedObject = map[string]any
type Convertor interface {
type Converter interface {
ConvertFrom(value any) (any, E.NestedError)
}
@ -188,7 +188,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
// - error: the error occurred during conversion, or nil if no error occurred.
func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
srcT := src.Type()
dstVT := dst.Type()
dstT := dst.Type()
if src.Kind() == reflect.Interface {
src = src.Elem()
@ -199,31 +199,36 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface()))
}
switch {
case srcT.AssignableTo(dstVT):
dst.Set(src)
case srcT.ConvertibleTo(dstVT):
dst.Set(src.Convert(dstVT))
case srcT.Kind() == reflect.Map:
if dstVT.Kind() != reflect.Map {
return E.TypeError("map", srcT, dstVT)
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
dst.Set(reflect.New(dstT.Elem()))
}
dst = dst.Elem()
dstT = dst.Type()
}
switch {
case srcT.AssignableTo(dstT):
dst.Set(src)
case srcT.ConvertibleTo(dstT):
dst.Set(src.Convert(dstT))
case srcT.Kind() == reflect.Map:
obj, ok := src.Interface().(SerializedObject)
if !ok {
return E.TypeError("map", srcT, dstVT)
return E.TypeMismatch[SerializedObject](src.Interface())
}
err := Deserialize(obj, dst.Addr().Interface())
if err != nil {
return err
}
case srcT.Kind() == reflect.Slice:
if dstVT.Kind() != reflect.Slice {
return E.TypeError("slice", srcT, dstVT)
if dstT.Kind() != reflect.Slice {
return E.TypeError("slice", srcT, dstT)
}
newSlice := reflect.MakeSlice(dstVT, 0, src.Len())
newSlice := reflect.MakeSlice(dstT, 0, src.Len())
i := 0
for _, v := range src.Seq2() {
tmp := reflect.New(dstVT.Elem()).Elem()
tmp := reflect.New(dstT.Elem()).Elem()
err := Convert(v, tmp)
if err != nil {
return err.Subjectf("[%d]", i)
@ -233,16 +238,27 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
}
dst.Set(newSlice)
default:
// check if Convertor is implemented
if converter, ok := dst.Interface().(Convertor); ok {
converted, err := converter.ConvertFrom(src.Interface())
if err != nil {
return err
var converter Converter
var ok bool
// check if (*T).Convertor is implemented
if converter, ok = dst.Addr().Interface().(Converter); !ok {
// check if (T).Convertor is implemented
converter, ok = dst.Interface().(Converter)
if !ok {
return E.TypeError("conversion", srcT, dstT)
}
dst.Set(reflect.ValueOf(converted))
return nil
}
return E.TypeError("conversion", srcT, dstVT)
converted, err := converter.ConvertFrom(src.Interface())
if err != nil {
return err
}
c := reflect.ValueOf(converted)
if c.Kind() == reflect.Ptr {
c = c.Elem()
}
dst.Set(c)
return nil
}
return nil