package middleware import ( "fmt" "os" "path" "sort" "github.com/goccy/go-yaml" "github.com/yusing/go-proxy/internal/gperr" ) var ErrMissingMiddlewareUse = gperr.New("missing middleware 'use' field") func BuildMiddlewaresFromComposeFile(filePath string, eb *gperr.Builder) map[string]*Middleware { fileContent, err := os.ReadFile(filePath) if err != nil { eb.Add(err) return nil } return BuildMiddlewaresFromYAML(path.Base(filePath), fileContent, eb) } func BuildMiddlewaresFromYAML(source string, data []byte, eb *gperr.Builder) map[string]*Middleware { var rawMap map[string][]map[string]any err := yaml.Unmarshal(data, &rawMap) if err != nil { eb.Add(err) return nil } middlewares := make(map[string]*Middleware) for name, defs := range rawMap { chain, err := BuildMiddlewareFromChainRaw(name, defs) if err != nil { eb.Add(err.Subject(source)) } else { middlewares[name+"@file"] = chain } } return middlewares } func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, gperr.Error) { middlewares := make([]*Middleware, 0, len(middlewaresMap)) errs := gperr.NewBuilder() invalidOpts := gperr.NewBuilder() for name, opts := range middlewaresMap { m, err := Get(name) if err != nil { errs.Add(err) continue } m, err = m.New(opts) if err != nil { invalidOpts.Add(err.Subject("middlewares." + name)) continue } middlewares = append(middlewares, m) } if invalidOpts.HasError() { errs.Add(invalidOpts.Error()) } sort.Sort(ByPriority(middlewares)) return middlewares, errs.Error() } func BuildMiddlewareFromMap(name string, middlewaresMap map[string]OptionsRaw) (*Middleware, gperr.Error) { compiled, err := compileMiddlewares(middlewaresMap) if err != nil { return nil, err } return NewMiddlewareChain(name, compiled), nil } // TODO: check conflict or duplicates. func BuildMiddlewareFromChainRaw(name string, defs []map[string]any) (*Middleware, gperr.Error) { chainErr := gperr.NewBuilder("") chain := make([]*Middleware, 0, len(defs)) for i, def := range defs { if def["use"] == nil || def["use"] == "" { chainErr.Add(ErrMissingMiddlewareUse.Subjectf("%s[%d]", name, i)) continue } baseName := def["use"].(string) base, err := Get(baseName) if err != nil { chainErr.Add(err.Subjectf("%s[%d]", name, i)) continue } delete(def, "use") m, err := base.New(def) if err != nil { chainErr.Add(err.Subjectf("%s[%d]", name, i)) continue } m.name = fmt.Sprintf("%s[%d]", name, i) chain = append(chain, m) } if chainErr.HasError() { return nil, chainErr.Error() } return NewMiddlewareChain(name, chain), nil }