fixed serialization and middleware compose

This commit is contained in:
yusing 2024-10-02 01:04:34 +08:00
parent 1bac96dc2a
commit ed887a5cfc
8 changed files with 74 additions and 48 deletions

View file

@ -24,6 +24,7 @@ import (
"github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker"
"github.com/yusing/go-proxy/internal/docker/idlewatcher" "github.com/yusing/go-proxy/internal/docker/idlewatcher"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/net/http/middleware"
R "github.com/yusing/go-proxy/internal/route" R "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/server" "github.com/yusing/go-proxy/internal/server"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
@ -80,8 +81,9 @@ func main() {
prepareDirectory(dir) prepareDirectory(dir)
} }
err := config.Load() middleware.LoadComposeFiles()
if err != nil {
if err := config.Load(); err != nil {
logrus.Warn(err) logrus.Warn(err)
} }
cfg := config.GetInstance() cfg := config.GetInstance()
@ -113,11 +115,6 @@ func main() {
} }
cfg.StartProxyProviders() cfg.StartProxyProviders()
if err.HasError() {
l.Warn(err)
}
cfg.WatchChanges() cfg.WatchChanges()
onShutdown.Add(docker.CloseAllClients) onShutdown.Add(docker.CloseAllClients)
@ -132,7 +129,7 @@ func main() {
if autocert != nil { if autocert != nil {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
if err = autocert.Setup(ctx); err != nil { if err := autocert.Setup(ctx); err != nil {
l.Fatal(err) l.Fatal(err)
} else { } else {
onShutdown.Add(cancel) onShutdown.Add(cancel)

View file

@ -20,6 +20,7 @@
- [Hide X-Forwarded-\*](#hide-x-forwarded-) - [Hide X-Forwarded-\*](#hide-x-forwarded-)
- [Set X-Forwarded-\*](#set-x-forwarded-) - [Set X-Forwarded-\*](#set-x-forwarded-)
- [Forward Authorization header (experimental)](#forward-authorization-header-experimental) - [Forward Authorization header (experimental)](#forward-authorization-header-experimental)
- [Middleware Compose](#middleware-compose)
- [Examples](#examples) - [Examples](#examples)
- [Authentik (untested, experimental)](#authentik-untested-experimental) - [Authentik (untested, experimental)](#authentik-untested-experimental)
@ -356,6 +357,14 @@ http:
[🔼Back to top](#table-of-content) [🔼Back to top](#table-of-content)
## Middleware Compose
Middleware compose is a way to create reusable middlewares in file(s), just like docker compose.
You may use them with `<middleware_name>@file`
See [example](../internal/net/http/middleware/test_data/middleware_compose.yml)
## Examples ## Examples
### Authentik (untested, experimental) ### Authentik (untested, experimental)

View file

@ -18,7 +18,7 @@ const (
ConfigExampleFileName = "config.example.yml" ConfigExampleFileName = "config.example.yml"
ConfigPath = ConfigBasePath + "/" + ConfigFileName ConfigPath = ConfigBasePath + "/" + ConfigFileName
MiddlewareDefsBasePath = ConfigBasePath + "/middlewares" MiddlewareComposeBasePath = ConfigBasePath + "/middlewares"
) )
const ( const (
@ -41,6 +41,7 @@ var (
ConfigBasePath, ConfigBasePath,
SchemaBasePath, SchemaBasePath,
ErrorPagesBasePath, ErrorPagesBasePath,
MiddlewareComposeBasePath,
} }
) )

View file

@ -1,6 +1,7 @@
package docker package docker
import ( import (
"strconv"
"strings" "strings"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -76,3 +77,11 @@ func BoolParser(value string) (any, E.NestedError) {
return nil, E.Invalid("boolean value", value) return nil, E.Invalid("boolean value", value)
} }
} }
func IntParser(value string) (any, E.NestedError) {
i, err := strconv.Atoi(value)
if err != nil {
return 0, E.Invalid("integer value", value)
}
return i, nil
}

View file

@ -13,9 +13,10 @@ import (
var testMiddlewareCompose []byte var testMiddlewareCompose []byte
func TestBuild(t *testing.T) { func TestBuild(t *testing.T) {
// middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose) middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
// ExpectNoError(t, err.Error())
data, err := E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error()) ExpectNoError(t, err.Error())
t.Log(string(data)) _, err = E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
// t.Log(string(data))
// TODO: test
} }

View file

@ -33,6 +33,7 @@ func init() {
"customerrorpage": CustomErrorPage, "customerrorpage": CustomErrorPage,
"realip": RealIP.m, "realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m, "cloudflarerealip": CloudflareRealIP.m,
"cidrwhitelist": CIDRWhiteList.m,
} }
names := make(map[*Middleware][]string) names := make(map[*Middleware][]string)
for name, m := range middlewares { for name, m := range middlewares {
@ -50,10 +51,11 @@ func init() {
m.name = names[0] m.name = names[0]
} }
} }
}
// TODO: seperate from init() func LoadComposeFiles() {
b := E.NewBuilder("failed to load middlewares") b := E.NewBuilder("failed to load middlewares")
middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0) middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0)
if err != nil { if err != nil {
logrus.Errorf("failed to list middleware definitions: %s", err) logrus.Errorf("failed to list middleware definitions: %s", err)
return return

View file

@ -45,14 +45,5 @@ func TestSetRealIP(t *testing.T) {
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) // ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected) ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
}) })
// TODO test
// t.Run("request_headers", func(t *testing.T) {
// result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
// middlewareOpt: opts,
// })
// ExpectNoError(t, err.Error())
// ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
// ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
// ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
// })
} }

View file

@ -13,7 +13,7 @@ import (
) )
type SerializedObject = map[string]any type SerializedObject = map[string]any
type Convertor interface { type Converter interface {
ConvertFrom(value any) (any, E.NestedError) ConvertFrom(value any) (any, E.NestedError)
} }
@ -188,7 +188,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError {
// - error: the error occurred during conversion, or nil if no error occurred. // - error: the error occurred during conversion, or nil if no error occurred.
func Convert(src reflect.Value, dst reflect.Value) E.NestedError { func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
srcT := src.Type() srcT := src.Type()
dstVT := dst.Type() dstT := dst.Type()
if src.Kind() == reflect.Interface { if src.Kind() == reflect.Interface {
src = src.Elem() src = src.Elem()
@ -199,31 +199,36 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface())) return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface()))
} }
switch { if dst.Kind() == reflect.Pointer {
case srcT.AssignableTo(dstVT): if dst.IsNil() {
dst.Set(src) dst.Set(reflect.New(dstT.Elem()))
case srcT.ConvertibleTo(dstVT):
dst.Set(src.Convert(dstVT))
case srcT.Kind() == reflect.Map:
if dstVT.Kind() != reflect.Map {
return E.TypeError("map", srcT, dstVT)
} }
dst = dst.Elem()
dstT = dst.Type()
}
switch {
case srcT.AssignableTo(dstT):
dst.Set(src)
case srcT.ConvertibleTo(dstT):
dst.Set(src.Convert(dstT))
case srcT.Kind() == reflect.Map:
obj, ok := src.Interface().(SerializedObject) obj, ok := src.Interface().(SerializedObject)
if !ok { if !ok {
return E.TypeError("map", srcT, dstVT) return E.TypeMismatch[SerializedObject](src.Interface())
} }
err := Deserialize(obj, dst.Addr().Interface()) err := Deserialize(obj, dst.Addr().Interface())
if err != nil { if err != nil {
return err return err
} }
case srcT.Kind() == reflect.Slice: case srcT.Kind() == reflect.Slice:
if dstVT.Kind() != reflect.Slice { if dstT.Kind() != reflect.Slice {
return E.TypeError("slice", srcT, dstVT) return E.TypeError("slice", srcT, dstT)
} }
newSlice := reflect.MakeSlice(dstVT, 0, src.Len()) newSlice := reflect.MakeSlice(dstT, 0, src.Len())
i := 0 i := 0
for _, v := range src.Seq2() { for _, v := range src.Seq2() {
tmp := reflect.New(dstVT.Elem()).Elem() tmp := reflect.New(dstT.Elem()).Elem()
err := Convert(v, tmp) err := Convert(v, tmp)
if err != nil { if err != nil {
return err.Subjectf("[%d]", i) return err.Subjectf("[%d]", i)
@ -233,16 +238,27 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError {
} }
dst.Set(newSlice) dst.Set(newSlice)
default: default:
// check if Convertor is implemented var converter Converter
if converter, ok := dst.Interface().(Convertor); ok { var ok bool
converted, err := converter.ConvertFrom(src.Interface()) // check if (*T).Convertor is implemented
if err != nil { if converter, ok = dst.Addr().Interface().(Converter); !ok {
return err // check if (T).Convertor is implemented
converter, ok = dst.Interface().(Converter)
if !ok {
return E.TypeError("conversion", srcT, dstT)
} }
dst.Set(reflect.ValueOf(converted))
return nil
} }
return E.TypeError("conversion", srcT, dstVT)
converted, err := converter.ConvertFrom(src.Interface())
if err != nil {
return err
}
c := reflect.ValueOf(converted)
if c.Kind() == reflect.Ptr {
c = c.Elem()
}
dst.Set(c)
return nil
} }
return nil return nil