From e951194bee8c89642bbbcb73402aff5d15d5cf7e Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 30 Sep 2024 19:00:27 +0800 Subject: [PATCH] fixed route not being updated on restart, added experimental middleware compose support --- .../net/http/middleware/middleware_builder.go | 104 ++++++++++++++++++ .../middleware/middleware_builder_test.go | 9 ++ internal/net/http/middleware/middlewares.go | 46 ++++---- internal/proxy/provider/docker.go | 10 +- internal/utils/serialization.go | 49 +++++++-- 5 files changed, 184 insertions(+), 34 deletions(-) create mode 100644 internal/net/http/middleware/middleware_builder.go create mode 100644 internal/net/http/middleware/middleware_builder_test.go diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go new file mode 100644 index 0000000..0ad89b8 --- /dev/null +++ b/internal/net/http/middleware/middleware_builder.go @@ -0,0 +1,104 @@ +package middleware + +import ( + "net/http" + "os" + + E "github.com/yusing/go-proxy/internal/error" + "gopkg.in/yaml.v3" +) + +func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middleware, outErr E.NestedError) { + b := E.NewBuilder("middlewares compile errors") + defer b.To(&outErr) + + var data map[string][]map[string]any + fileContent, err := os.ReadFile(filePath) + if err != nil { + b.Add(E.FailWith("read file", err)) + return + } + err = yaml.Unmarshal(fileContent, &data) + if err != nil { + b.Add(E.FailWith("toml unmarshal", err)) + return + } + middlewares = make(map[string]*Middleware) + for name, defs := range data { + chainErr := E.NewBuilder("errors in middleware chain %s", name) + chain := make([]*Middleware, 0, len(defs)) + for i, def := range defs { + if def["use"] == nil || def["use"].(string) == "" { + chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i)) + continue + } + baseName := def["use"].(string) + base, ok := Get(baseName) + if !ok { + chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i)) + continue + } + delete(def, "use") + m, err := base.withOptions(def) + if err != nil { + chainErr.Add(err.Subjectf("%s.%d", name, i)) + continue + } + chain = append(chain, m) + } + if chainErr.HasError() { + b.Add(chainErr.Build()) + } else { + name = name + "@file" + middlewares[name] = BuildMiddlewareFromChain(name, chain) + } + } + return +} + +// TODO: check conflict or duplicates +func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware { + var ( + befores []BeforeFunc + rewrites []RewriteFunc + modResps []ModifyResponseFunc + ) + for _, m := range chain { + if m.before != nil { + befores = append(befores, m.before) + } + if m.rewrite != nil { + rewrites = append(rewrites, m.rewrite) + } + if m.modifyResponse != nil { + modResps = append(modResps, m.modifyResponse) + } + } + + m := &Middleware{name: name} + if len(befores) > 0 { + m.before = func(next http.Handler, w ResponseWriter, r *Request) { + for _, before := range befores { + before(next, w, r) + } + } + } + if len(rewrites) > 0 { + m.rewrite = func(r *Request) { + for _, rewrite := range rewrites { + rewrite(r) + } + } + } + if len(modResps) > 0 { + m.modifyResponse = func(res *Response) error { + b := E.NewBuilder("errors in middleware %s", name) + for _, mr := range modResps { + b.AddE(mr(res)) + } + return b.Build().Error() + } + } + + return m +} diff --git a/internal/net/http/middleware/middleware_builder_test.go b/internal/net/http/middleware/middleware_builder_test.go new file mode 100644 index 0000000..31dc697 --- /dev/null +++ b/internal/net/http/middleware/middleware_builder_test.go @@ -0,0 +1,9 @@ +package middleware + +import ( + "testing" +) + +func TestBuild(t *testing.T) { + +} diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index db70d57..7ef542a 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -2,10 +2,14 @@ package middleware import ( "fmt" + "path" "strings" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" + E "github.com/yusing/go-proxy/internal/error" + U "github.com/yusing/go-proxy/internal/utils" ) var middlewares map[string]*Middleware @@ -46,27 +50,27 @@ func init() { } } // TODO: seperate from init() - // b := E.NewBuilder("failed to load middlewares") - // middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0) - // if err != nil { - // logrus.Errorf("failed to list middleware definitions: %s", err) - // return - // } - // for _, defFile := range middlewareDefs { - // mws, err := BuildMiddlewaresFromYAML(defFile) - // for name, m := range mws { - // if _, ok := middlewares[name]; ok { - // b.Add(E.Duplicated("middleware", name)) - // continue - // } - // middlewares[name] = m - // logger.Infof("middleware %s loaded from %s", name, path.Base(defFile)) - // } - // b.Add(err.Subject(defFile)) - // } - // if b.HasError() { - // logger.Error(b.Build()) - // } + b := E.NewBuilder("failed to load middlewares") + middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0) + if err != nil { + logrus.Errorf("failed to list middleware definitions: %s", err) + return + } + for _, defFile := range middlewareDefs { + mws, err := BuildMiddlewaresFromYAML(defFile) + for name, m := range mws { + if _, ok := middlewares[name]; ok { + b.Add(E.Duplicated("middleware", name)) + continue + } + middlewares[name] = m + logger.Infof("middleware %s loaded from %s", name, path.Base(defFile)) + } + b.Add(err.Subject(defFile)) + } + if b.HasError() { + logger.Error(b.Build()) + } } var logger = logrus.WithField("module", "middlewares") diff --git a/internal/proxy/provider/docker.go b/internal/proxy/provider/docker.go index dbe74cd..46fd408 100755 --- a/internal/proxy/provider/docker.go +++ b/internal/proxy/provider/docker.go @@ -12,6 +12,7 @@ import ( M "github.com/yusing/go-proxy/internal/models" R "github.com/yusing/go-proxy/internal/route" W "github.com/yusing/go-proxy/internal/watcher" + "github.com/yusing/go-proxy/internal/watcher/events" ) type DockerProvider struct { @@ -80,10 +81,17 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { func (p *DockerProvider) shouldIgnore(container D.Container) bool { return container.IsExcluded || - !container.IsExplicit && p.ExplicitOnly + !container.IsExplicit && p.ExplicitOnly || + strings.HasSuffix(container.ContainerName, "-old") } func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) { + switch event.Action { + case events.ActionContainerStart, events.ActionContainerDie: + break + default: + return + } b := E.NewBuilder("event %s error", event) defer b.To(&res.err) diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 9b54876..9e52583 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -110,24 +110,42 @@ func Deserialize(src SerializedObject, target any) E.NestedError { if src == nil || target == nil { return nil } + + tValue := reflect.ValueOf(target) + mapping := make(map[string]string) + + if tValue.Kind() == reflect.Ptr { + tValue = tValue.Elem() + } + // convert data fields to lower no-snake // convert target fields to lower no-snake // then check if the field of data is in the target - mapping := make(map[string]string) - t := reflect.TypeOf(target).Elem() - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - snakeCaseField := ToLowerNoSnake(field.Name) - mapping[snakeCaseField] = field.Name - } - tValue := reflect.ValueOf(target) - if tValue.IsZero() { - return E.Invalid("value", "nil") + + if tValue.Kind() == reflect.Struct { + t := reflect.TypeOf(target).Elem() + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + snakeCaseField := ToLowerNoSnake(field.Name) + mapping[snakeCaseField] = field.Name + } + } else if tValue.Kind() == reflect.Map && tValue.Type().Key().Kind() == reflect.String { + if tValue.IsNil() { + tValue.Set(reflect.MakeMap(tValue.Type())) + } + for k := range src { + // TODO: type check + tValue.SetMapIndex(reflect.ValueOf(ToLowerNoSnake(k)), reflect.ValueOf(src[k])) + } + return nil + } else { + return E.Unsupported("target type", fmt.Sprintf("%T", target)) } + for k, v := range src { kCleaned := ToLowerNoSnake(k) if fieldName, ok := mapping[kCleaned]; ok { - prop := reflect.ValueOf(target).Elem().FieldByName(fieldName) + prop := tValue.FieldByName(fieldName) propType := prop.Type() isPtr := prop.Kind() == reflect.Ptr if prop.CanSet() { @@ -157,7 +175,14 @@ func Deserialize(src SerializedObject, target any) E.NestedError { } prop.Set(propNew) default: - return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType) + obj, ok := val.Interface().(SerializedObject) + if !ok { + return E.Invalid("conversion", k).Extraf("from %s to %s", vType, propType) + } + err := Deserialize(obj, prop.Addr().Interface()) + if err.HasError() { + return E.Failure("set field").With(err).Subject(k) + } } } else { return E.Unsupported("field", k).Extraf("type %s is not settable", propType)