diff --git a/internal/config/agent_pool.go b/internal/config/agent_pool.go new file mode 100644 index 0000000..8f5f53d --- /dev/null +++ b/internal/config/agent_pool.go @@ -0,0 +1,66 @@ +package config + +import ( + "slices" + + "github.com/yusing/go-proxy/agent/pkg/agent" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/route/provider" + "github.com/yusing/go-proxy/internal/utils/functional" +) + +var agentPool = functional.NewMapOf[string, *agent.AgentConfig]() + +func addAgent(agent *agent.AgentConfig) { + agentPool.Store(agent.Addr, agent) +} + +func removeAllAgents() { + agentPool.Clear() +} + +func GetAgent(addr string) (agent *agent.AgentConfig, ok bool) { + agent, ok = agentPool.Load(addr) + return +} + +func (cfg *Config) GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) { + if !agent.IsDockerHostAgent(agentAddrOrDockerHost) { + return GetAgent(agentAddrOrDockerHost) + } + return GetAgent(agent.GetAgentAddrFromDockerHost(agentAddrOrDockerHost)) +} + +func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) { + if slices.ContainsFunc(cfg.value.Providers.Agents, func(a *agent.AgentConfig) bool { + return a.Addr == host + }) { + return 0, gperr.New("agent already exists") + } + + var agentCfg agent.AgentConfig + agentCfg.Addr = host + err := agentCfg.StartWithCerts(cfg.Task(), ca.Cert, client.Cert, client.Key) + if err != nil { + return 0, gperr.Wrap(err, "failed to start agent") + } + addAgent(&agentCfg) + + provider := provider.NewAgentProvider(&agentCfg) + if err := cfg.errIfExists(provider); err != nil { + return 0, err + } + err = provider.LoadRoutes() + if err != nil { + return 0, gperr.Wrap(err, "failed to load routes") + } + return provider.NumRoutes(), nil +} + +func (cfg *Config) ListAgents() []*agent.AgentConfig { + agents := make([]*agent.AgentConfig, 0, agentPool.Size()) + agentPool.RangeAll(func(key string, value *agent.AgentConfig) { + agents = append(agents, value) + }) + return agents +} diff --git a/internal/config/config.go b/internal/config/config.go index 20d6603..5cc3bef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "context" + "errors" "os" "strconv" "strings" @@ -11,11 +12,11 @@ import ( "github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config/types" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/entrypoint" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/net/http/server" + "github.com/yusing/go-proxy/internal/net/gphttp/server" "github.com/yusing/go-proxy/internal/notif" proxy "github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/task" @@ -26,7 +27,7 @@ import ( ) type Config struct { - value *types.Config + value *config.Config providers F.Map[string, *proxy.Provider] autocertProvider *autocert.Provider entrypoint *entrypoint.Entrypoint @@ -35,7 +36,6 @@ type Config struct { } var ( - instance *Config cfgWatcher watcher.Watcher reloadMu sync.Mutex ) @@ -49,15 +49,11 @@ Make sure you rename it back before next time you start.` You may run "ls-config" to show or dump the current config.` ) -var Validate = types.Validate - -func GetInstance() *Config { - return instance -} +var Validate = config.Validate func newConfig() *Config { return &Config{ - value: types.DefaultConfig(), + value: config.DefaultConfig(), providers: F.NewMapOf[string, *proxy.Provider](), entrypoint: entrypoint.NewEntrypoint(), task: task.RootTask("config", false), @@ -65,16 +61,17 @@ func newConfig() *Config { } func Load() (*Config, gperr.Error) { - if instance != nil { - return instance, nil + if config.HasInstance() { + panic(errors.New("config already loaded")) } - instance = newConfig() + cfg := newConfig() + config.SetInstance(cfg) cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName) - return instance, instance.load() + return cfg, cfg.load() } func MatchDomains() []string { - return instance.value.MatchDomains + return config.GetInstance().Value().MatchDomains } func WatchChanges() { @@ -122,22 +119,25 @@ func Reload() gperr.Error { // cancel all current subtasks -> wait // -> replace config -> start new subtasks - instance.task.Finish("config changed") - instance = newCfg - instance.Start(StartAllServers) + config.GetInstance().(*Config).Task().Finish("config changed") + newCfg.Start(StartAllServers) + config.SetInstance(newCfg) return nil } -func (cfg *Config) Value() *types.Config { - return instance.value +func (cfg *Config) Value() *config.Config { + return cfg.value } func (cfg *Config) Reload() gperr.Error { return Reload() } +// AutoCertProvider returns the autocert provider. +// +// If the autocert provider is not configured, it returns nil. func (cfg *Config) AutoCertProvider() *autocert.Provider { - return instance.autocertProvider + return cfg.autocertProvider } func (cfg *Config) Task() *task.Task { @@ -217,7 +217,7 @@ func (cfg *Config) load() gperr.Error { gperr.LogFatal(errMsg, err) } - model := types.DefaultConfig() + model := config.DefaultConfig() if err := utils.DeserializeYAML(data, model); err != nil { gperr.LogFatal(errMsg, err) } @@ -260,31 +260,65 @@ func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err gperr return } -func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error { +func (cfg *Config) errIfExists(p *proxy.Provider) gperr.Error { + if _, ok := cfg.providers.Load(p.String()); ok { + return gperr.Errorf("provider %s already exists", p.String()) + } + return nil +} - lenLongestName := 0 +func (cfg *Config) storeProvider(p *proxy.Provider) { + cfg.providers.Store(p.String(), p) +} + +func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error { + errs := gperr.NewBuilder("route provider errors") + results := gperr.NewBuilder("loaded route providers") + + removeAllAgents() + + for _, agent := range providers.Agents { + if err := agent.Start(cfg.task); err != nil { + errs.Add(err.Subject(agent.String())) + continue + } + addAgent(agent) + p := proxy.NewAgentProvider(agent) + if err := cfg.errIfExists(p); err != nil { + errs.Add(err.Subject(p.String())) + continue + } + cfg.storeProvider(p) + } for _, filename := range providers.Files { p, err := proxy.NewFileProvider(filename) + if err == nil { + err = cfg.errIfExists(p) + } if err != nil { - errs.Add(E.PrependSubject(filename, err)) + errs.Add(gperr.PrependSubject(filename, err)) continue } - cfg.providers.Store(p.String(), p) - if len(p.String()) > lenLongestName { - lenLongestName = len(p.String()) - } + cfg.storeProvider(p) } for name, dockerHost := range providers.Docker { - p, err := proxy.NewDockerProvider(name, dockerHost) - if err != nil { - errs.Add(E.PrependSubject(name, err)) + p := proxy.NewDockerProvider(name, dockerHost) + if err := cfg.errIfExists(p); err != nil { + errs.Add(err.Subject(p.String())) continue } - cfg.providers.Store(p.String(), p) - if len(p.String()) > lenLongestName { - lenLongestName = len(p.String()) - } + cfg.storeProvider(p) } + if cfg.providers.Size() == 0 { + return nil + } + + lenLongestName := 0 + cfg.providers.RangeAll(func(k string, _ *proxy.Provider) { + if len(k) > lenLongestName { + lenLongestName = len(k) + } + }) cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { if err := p.LoadRoutes(); err != nil { errs.Add(err.Subject(p.String())) diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 7c8ed98..f8e693b 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -3,14 +3,15 @@ package types import ( "context" "regexp" + "sync" "github.com/go-playground/validator/v10" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/net/gphttp/accesslog" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/utils" - - E "github.com/yusing/go-proxy/internal/error" ) type ( @@ -23,9 +24,10 @@ type ( TimeoutShutdown int `json:"timeout_shutdown" validate:"gte=0"` } Providers struct { - Files []string `json:"include" validate:"dive,filepath"` - Docker map[string]string `json:"docker" validate:"dive,unix_addr|url"` - Notification []notif.NotificationConfig `json:"notification"` + Files []string `json:"include" yaml:"include,omitempty" validate:"dive,filepath"` + Docker map[string]string `json:"docker" yaml:"docker,omitempty" validate:"non_empty_docker_keys,dive,unix_addr|url"` + Agents []*agent.AgentConfig `json:"agents" yaml:"agents,omitempty"` + Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"` } Entrypoint struct { Middlewares []map[string]any `json:"middlewares"` @@ -38,9 +40,18 @@ type ( Statistics() map[string]any RouteProviderList() []string Context() context.Context + GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) + VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) + ListAgents() []*agent.AgentConfig + AutoCertProvider() *autocert.Provider } ) +var ( + instance ConfigInstance + instanceMu sync.RWMutex +) + func DefaultConfig() *Config { return &Config{ TimeoutShutdown: 3, @@ -50,7 +61,25 @@ func DefaultConfig() *Config { } } -func Validate(data []byte) E.Error { +func GetInstance() ConfigInstance { + instanceMu.RLock() + defer instanceMu.RUnlock() + return instance +} + +func SetInstance(cfg ConfigInstance) { + instanceMu.Lock() + defer instanceMu.Unlock() + instance = cfg +} + +func HasInstance() bool { + instanceMu.RLock() + defer instanceMu.RUnlock() + return instance != nil +} + +func Validate(data []byte) gperr.Error { var model Config return utils.DeserializeYAML(data, &model) } @@ -68,4 +97,13 @@ func init() { } return true }) + utils.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool { + m := fl.Field().Interface().(map[string]string) + for k := range m { + if k == "" { + return false + } + } + return true + }) }