diff --git a/internal/api/handler.go b/internal/api/handler.go index f8ea7b0..af76837 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -30,9 +30,9 @@ func NewHandler(cfg *config.Config) http.Handler { mux.HandleFunc("GET", "/v1/list", wrap(cfg, v1.List)) mux.HandleFunc("GET", "/v1/list/{what}", wrap(cfg, v1.List)) mux.HandleFunc("GET", "/v1/file", v1.GetFileContent) - mux.HandleFunc("GET", "/v1/file/{filename}", v1.GetFileContent) - mux.HandleFunc("POST", "/v1/file/{filename}", v1.SetFileContent) - mux.HandleFunc("PUT", "/v1/file/{filename}", v1.SetFileContent) + mux.HandleFunc("GET", "/v1/file/{filename...}", v1.GetFileContent) + mux.HandleFunc("POST", "/v1/file/{filename...}", v1.SetFileContent) + mux.HandleFunc("PUT", "/v1/file/{filename...}", v1.SetFileContent) mux.HandleFunc("GET", "/v1/stats", wrap(cfg, v1.Stats)) mux.HandleFunc("GET", "/v1/error_page", error_page.GetHandleFunc()) return mux diff --git a/internal/api/v1/file.go b/internal/api/v1/file.go index ed17b0d..e413ae5 100644 --- a/internal/api/v1/file.go +++ b/internal/api/v1/file.go @@ -5,6 +5,7 @@ import ( "net/http" "os" "path" + "strings" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" @@ -41,7 +42,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { var validateErr E.NestedError if filename == common.ConfigFileName { validateErr = config.Validate(content) - } else { + } else if !strings.HasPrefix(filename, path.Base(common.MiddlewareComposeBasePath)) { validateErr = provider.Validate(content) } diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 1550d50..3317646 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -1,19 +1,20 @@ package v1 import ( - "encoding/json" "net/http" - "os" + "strings" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/net/http/middleware" + "github.com/yusing/go-proxy/internal/utils" ) const ( ListRoutes = "routes" ListConfigFiles = "config_files" + ListMiddlewares = "middlewares" ListMiddlewareTrace = "middleware_trace" ) @@ -28,6 +29,8 @@ func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) { listRoutes(cfg, w, r) case ListConfigFiles: listConfigFiles(w, r) + case ListMiddlewares: + listMiddlewares(w, r) case ListMiddlewareTrace: listMiddlewareTrace(w, r) default: @@ -46,34 +49,25 @@ func listRoutes(cfg *config.Config, w http.ResponseWriter, r *http.Request) { } } - if err := U.RespondJson(w, routes); err != nil { - U.HandleErr(w, r, err) - } + U.HandleErr(w, r, U.RespondJson(w, routes)) } func listConfigFiles(w http.ResponseWriter, r *http.Request) { - files, err := os.ReadDir(common.ConfigBasePath) + files, err := utils.ListFiles(common.ConfigBasePath, 1) if err != nil { U.HandleErr(w, r, err) return } - filenames := make([]string, len(files)) - for i, f := range files { - filenames[i] = f.Name() + for i := range files { + files[i] = strings.TrimPrefix(files[i], common.ConfigBasePath+"/") } - resp, err := json.Marshal(filenames) - if err != nil { - U.HandleErr(w, r, err) - return - } - w.Write(resp) + U.HandleErr(w, r, U.RespondJson(w, files)) } func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) { - resp, err := json.Marshal(middleware.GetAllTrace()) - if err != nil { - U.HandleErr(w, r, err) - return - } - w.Write(resp) + U.HandleErr(w, r, U.RespondJson(w, middleware.GetAllTrace())) +} + +func listMiddlewares(w http.ResponseWriter, r *http.Request) { + U.HandleErr(w, r, U.RespondJson(w, middleware.All())) } diff --git a/internal/api/v1/utils/error.go b/internal/api/v1/utils/error.go index 1a9f58c..957e6cd 100644 --- a/internal/api/v1/utils/error.go +++ b/internal/api/v1/utils/error.go @@ -12,6 +12,9 @@ import ( var Logger = logrus.WithField("module", "api") func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) { + if origErr == nil { + return + } err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL) Logger.Error(err) if len(code) > 0 { diff --git a/internal/config/query.go b/internal/config/query.go index 0337699..cef4182 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -38,6 +38,7 @@ func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { obj["provider"] = p.GetName() obj["type"] = string(r.Type()) obj["started"] = r.Started() + obj["raw"] = r.Entry() routes[alias] = obj }) return routes @@ -46,30 +47,17 @@ func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { func (cfg *Config) Statistics() map[string]any { nTotalStreams := 0 nTotalRPs := 0 - providerStats := make(map[string]any) + providerStats := make(map[string]PR.ProviderStats) - cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { - if !r.Started() { - return - } - s, ok := providerStats[p.GetName()] - if !ok { - s = make(map[string]int) - } - - stats := s.(map[string]int) - switch r.Type() { - case R.RouteTypeStream: - stats["num_streams"]++ - nTotalStreams++ - case R.RouteTypeReverseProxy: - stats["num_reverse_proxies"]++ - nTotalRPs++ - default: - panic("bug: should not reach here") - } + cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { + providerStats[name] = p.Statistics() }) + for _, stats := range providerStats { + nTotalRPs += stats.NumRPs + nTotalStreams += stats.NumStreams + } + return map[string]any{ "num_total_streams": nTotalStreams, "num_total_reverse_proxies": nTotalRPs, diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index ada30b9..aebec82 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -20,6 +20,10 @@ func Get(name string) (middleware *Middleware, ok bool) { return } +func All() map[string]*Middleware { + return middlewares +} + // initialize middleware names and label parsers func init() { middlewares = map[string]*Middleware{ diff --git a/internal/proxy/provider/provider.go b/internal/proxy/provider/provider.go index a637422..b58cdc8 100644 --- a/internal/proxy/provider/provider.go +++ b/internal/proxy/provider/provider.go @@ -31,8 +31,13 @@ type ( OnEvent(event W.Event, routes R.Routes) EventResult String() string } - ProviderType string - EventResult struct { + ProviderType string + ProviderStats struct { + NumRPs int `json:"num_reverse_proxies"` + NumStreams int `json:"num_streams"` + Type ProviderType `json:"type"` + } + EventResult struct { nRemoved int nAdded int err E.NestedError @@ -164,6 +169,24 @@ func (p *Provider) LoadRoutes() E.NestedError { return E.FailWith("loading routes", err) } +func (p *Provider) Statistics() ProviderStats { + numRPs := 0 + numStreams := 0 + p.routes.RangeAll(func(_ string, r R.Route) { + switch r.Type() { + case R.RouteTypeReverseProxy: + numRPs++ + case R.RouteTypeStream: + numStreams++ + } + }) + return ProviderStats{ + NumRPs: numRPs, + NumStreams: numStreams, + Type: p.t, + } +} + func (p *Provider) watchEvents() { p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background()) events, errs := p.watcher.Events(p.watcherCtx) diff --git a/internal/utils/fs.go b/internal/utils/fs.go index 09855df..2332220 100644 --- a/internal/utils/fs.go +++ b/internal/utils/fs.go @@ -23,9 +23,7 @@ func ListFiles(dir string, maxDepth int) ([]string, error) { if err != nil { return nil, err } - for _, subEntry := range subEntries { - files = append(files, path.Join(dir, entry.Name(), subEntry)) - } + files = append(files, subEntries...) } else { files = append(files, path.Join(dir, entry.Name())) }