From 4f94a0f08a42572489f2787790ef926d922d5d52 Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 24 Feb 2025 03:28:23 +0800 Subject: [PATCH] improved add agent mechanism --- internal/api/handler.go | 2 +- internal/api/v1/new_agent.go | 4 ++-- internal/config/agent_pool.go | 19 +++++++++++++------ internal/config/types/config.go | 2 +- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/internal/api/handler.go b/internal/api/handler.go index 4dff2dc..a16e19b 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -84,7 +84,7 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true) mux.HandleFunc("GET", "/v1/agents", v1.ListAgents, true) mux.HandleFunc("GET", "/v1/agents/new", v1.NewAgent, true) - mux.HandleFunc("POST", "/v1/agents/add", v1.AddAgent, true) + mux.HandleFunc("POST", "/v1/agents/verify", v1.VerifyNewAgent, true) mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true) mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true) mux.HandleFunc("GET", "/v1/cert/info", certapi.GetCertInfo, true) diff --git a/internal/api/v1/new_agent.go b/internal/api/v1/new_agent.go index 01ad1c9..2f1a124 100644 --- a/internal/api/v1/new_agent.go +++ b/internal/api/v1/new_agent.go @@ -95,7 +95,7 @@ func NewAgent(w http.ResponseWriter, r *http.Request) { }) } -func AddAgent(w http.ResponseWriter, r *http.Request) { +func VerifyNewAgent(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() clientPEMData, err := io.ReadAll(r.Body) if err != nil { @@ -114,7 +114,7 @@ func AddAgent(w http.ResponseWriter, r *http.Request) { return } - nRoutesAdded, err := config.GetInstance().AddAgent(data.Host, data.CA, data.Client) + nRoutesAdded, err := config.GetInstance().VerifyNewAgent(data.Host, data.CA, data.Client) if err != nil { gphttp.ClientError(w, err) return diff --git a/internal/config/agent_pool.go b/internal/config/agent_pool.go index 9884211..8f5f53d 100644 --- a/internal/config/agent_pool.go +++ b/internal/config/agent_pool.go @@ -1,9 +1,10 @@ package config import ( + "slices" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/utils/functional" ) @@ -30,7 +31,13 @@ func (cfg *Config) GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, b return GetAgent(agent.GetAgentAddrFromDockerHost(agentAddrOrDockerHost)) } -func (cfg *Config) AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) { +func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) { + if slices.ContainsFunc(cfg.value.Providers.Agents, func(a *agent.AgentConfig) bool { + return a.Addr == host + }) { + return 0, gperr.New("agent already exists") + } + var agentCfg agent.AgentConfig agentCfg.Addr = host err := agentCfg.StartWithCerts(cfg.Task(), ca.Cert, client.Cert, client.Key) @@ -43,10 +50,10 @@ func (cfg *Config) AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) if err := cfg.errIfExists(provider); err != nil { return 0, err } - provider.LoadRoutes() - provider.Start(cfg.Task()) - cfg.storeProvider(provider) - logging.Info().Msgf("Added agent %s with %d routes", host, provider.NumRoutes()) + err = provider.LoadRoutes() + if err != nil { + return 0, gperr.Wrap(err, "failed to load routes") + } return provider.NumRoutes(), nil } diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 76d5ae9..f8e693b 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -41,7 +41,7 @@ type ( RouteProviderList() []string Context() context.Context GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) - AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) + VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) ListAgents() []*agent.AgentConfig AutoCertProvider() *autocert.Provider }