diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index c5c5c59..23aba0c 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -8,7 +8,6 @@ import ( "net/http" "os" "strings" - "sync" "time" "github.com/rs/zerolog" @@ -18,21 +17,18 @@ import ( gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/task" - "github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/pkg" "golang.org/x/net/context" ) -type ( - AgentConfig struct { - Addr string +type AgentConfig struct { + Addr string - httpClient *http.Client - tlsConfig *tls.Config - name string - l zerolog.Logger - } -) + httpClient *http.Client + tlsConfig *tls.Config + name string + l zerolog.Logger +} const ( EndpointVersion = "/version" @@ -54,42 +50,24 @@ const ( ) var ( - agents = functional.NewMapOf[string, *AgentConfig]() - agentMapMu sync.RWMutex -) - -var ( - HTTPProxyURL = types.MustParseURL(APIBaseURL + EndpointProxyHTTP) - HTTPProxyURLStripLen = len(APIEndpointBase + EndpointProxyHTTP) + HTTPProxyURL = types.MustParseURL(APIBaseURL + EndpointProxyHTTP) + HTTPProxyURLPrefixLen = len(APIEndpointBase + EndpointProxyHTTP) ) func IsDockerHostAgent(dockerHost string) bool { return strings.HasPrefix(dockerHost, FakeDockerHostPrefix) } -func GetAgentFromDockerHost(dockerHost string) (*AgentConfig, bool) { - if !IsDockerHostAgent(dockerHost) { - return nil, false - } - return agents.Load(dockerHost[FakeDockerHostPrefixLen:]) +func GetAgentAddrFromDockerHost(dockerHost string) string { + return dockerHost[FakeDockerHostPrefixLen:] } func (cfg *AgentConfig) FakeDockerHost() string { - return FakeDockerHostPrefix + cfg.Name() + return FakeDockerHostPrefix + cfg.Addr } func (cfg *AgentConfig) Parse(addr string) error { cfg.Addr = addr - return cfg.load() -} - -func (cfg *AgentConfig) errIfNameExists() E.Error { - agentMapMu.RLock() - defer agentMapMu.RUnlock() - agent, ok := agents.Load(cfg.Name()) - if ok { - return E.Errorf("agent with name %s (%s) already exists", cfg.Name(), agent.Addr) - } return nil } @@ -101,13 +79,7 @@ func checkVersion(a, b string) bool { return withoutBuildTime(a) == withoutBuildTime(b) } -func (cfg *AgentConfig) Remove() { - agentMapMu.Lock() - defer agentMapMu.Unlock() - agents.Delete(cfg.Name()) -} - -func (cfg *AgentConfig) load() E.Error { +func (cfg *AgentConfig) Start(parent task.Parent) E.Error { certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) if err != nil { if os.IsNotExist(err) { @@ -141,7 +113,7 @@ func (cfg *AgentConfig) load() E.Error { // create transport and http client cfg.httpClient = cfg.NewHTTPClient() - ctx, cancel := context.WithTimeout(task.RootContext(), 5*time.Second) + ctx, cancel := context.WithTimeout(parent.Context(), 5*time.Second) defer cancel() // check agent version @@ -160,15 +132,10 @@ func (cfg *AgentConfig) load() E.Error { return E.Wrap(err) } - // check if agent name is already used cfg.name = string(name) - if err := cfg.errIfNameExists(); err != nil { - return err - } - cfg.l = logging.With().Str("agent", cfg.name).Logger() - agents.Store(cfg.name, cfg) + logging.Info().Msgf("agent %q started", cfg.name) return nil } @@ -195,7 +162,7 @@ func (cfg *AgentConfig) Name() string { } func (cfg *AgentConfig) String() string { - return "agent@" + cfg.Name() + return "agent@" + cfg.Addr } func (cfg *AgentConfig) MarshalJSON() ([]byte, error) { diff --git a/agent/pkg/agent/requests.go b/agent/pkg/agent/requests.go index 40ad158..a6efb5b 100644 --- a/agent/pkg/agent/requests.go +++ b/agent/pkg/agent/requests.go @@ -5,13 +5,11 @@ import ( "net/http" "github.com/coder/websocket" - "github.com/yusing/go-proxy/internal/logging" "golang.org/x/net/context" ) func (cfg *AgentConfig) Do(ctx context.Context, method, endpoint string, body io.Reader) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, method, APIBaseURL+endpoint, body) - logging.Debug().Msgf("request: %s %s", method, req.URL.String()) if err != nil { return nil, err } diff --git a/agent/pkg/handler/proxy_http.go b/agent/pkg/handler/proxy_http.go index 63f234b..d59b9e7 100644 --- a/agent/pkg/handler/proxy_http.go +++ b/agent/pkg/handler/proxy_http.go @@ -49,7 +49,7 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = "" r.URL.Host = "" - r.URL.Path = r.URL.Path[agent.HTTPProxyURLStripLen:] // strip the {API_BASE}/proxy/http prefix + r.URL.Path = r.URL.Path[agent.HTTPProxyURLPrefixLen:] // strip the {API_BASE}/proxy/http prefix r.RequestURI = r.URL.String() r.URL.Host = host r.URL.Scheme = scheme diff --git a/internal/config/config.go b/internal/config/config.go index 7ae78c4..8e64614 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,16 +2,18 @@ package config import ( "context" + "errors" "os" "strconv" "strings" "sync" "time" + "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/api" "github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/config/types" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/entrypoint" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" @@ -26,7 +28,7 @@ import ( ) type Config struct { - value *types.Config + value *config.Config providers F.Map[string, *proxy.Provider] autocertProvider *autocert.Provider entrypoint *entrypoint.Entrypoint @@ -35,7 +37,6 @@ type Config struct { } var ( - instance *Config cfgWatcher watcher.Watcher reloadMu sync.Mutex ) @@ -49,15 +50,15 @@ Make sure you rename it back before next time you start.` You may run "ls-config" to show or dump the current config.` ) -var Validate = types.Validate +var Validate = config.Validate func GetInstance() *Config { - return instance + return config.GetInstance().(*Config) } func newConfig() *Config { return &Config{ - value: types.DefaultConfig(), + value: config.DefaultConfig(), providers: F.NewMapOf[string, *proxy.Provider](), entrypoint: entrypoint.NewEntrypoint(), task: task.RootTask("config", false), @@ -65,16 +66,17 @@ func newConfig() *Config { } func Load() (*Config, E.Error) { - if instance != nil { - return instance, nil + if config.HasInstance() { + panic(errors.New("config already loaded")) } - instance = newConfig() + cfg := newConfig() + config.SetInstance(cfg) cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName) - return instance, instance.load() + return cfg, cfg.load() } func MatchDomains() []string { - return instance.value.MatchDomains + return GetInstance().Value().MatchDomains } func WatchChanges() { @@ -122,14 +124,14 @@ func Reload() E.Error { // cancel all current subtasks -> wait // -> replace config -> start new subtasks - instance.task.Finish("config changed") - instance = newCfg - instance.Start(StartAllServers) + GetInstance().Task().Finish("config changed") + newCfg.Start(StartAllServers) + config.SetInstance(newCfg) return nil } -func (cfg *Config) Value() *types.Config { - return instance.value +func (cfg *Config) Value() *config.Config { + return cfg.value } func (cfg *Config) Reload() E.Error { @@ -137,7 +139,7 @@ func (cfg *Config) Reload() E.Error { } func (cfg *Config) AutoCertProvider() *autocert.Provider { - return instance.autocertProvider + return cfg.autocertProvider } func (cfg *Config) Task() *task.Task { @@ -217,7 +219,7 @@ func (cfg *Config) load() E.Error { E.LogFatal(errMsg, err) } - model := types.DefaultConfig() + model := config.DefaultConfig() if err := utils.DeserializeYAML(data, model); err != nil { E.LogFatal(errMsg, err) } @@ -260,39 +262,74 @@ func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err E.Err return } -func (cfg *Config) loadRouteProviders(providers *types.Providers) E.Error { +func (cfg *Config) errIfExists(p *proxy.Provider) E.Error { + if _, ok := cfg.providers.Load(p.String()); ok { + return E.Errorf("provider %s already exists", p.String()) + } + return nil +} + +func (cfg *Config) storeProvider(p *proxy.Provider) { + cfg.providers.Store(p.String(), p) +} + +func (cfg *Config) GetAgent(agentDockerHost string) (*agent.AgentConfig, bool) { + if !agent.IsDockerHostAgent(agentDockerHost) { + panic(errors.New("invalid use of GetAgent with docker host: " + agentDockerHost)) + } + key := "agent@" + agent.GetAgentAddrFromDockerHost(agentDockerHost) + p, ok := cfg.providers.Load(key) + if !ok { + return nil, false + } + return p.ProviderImpl.(*proxy.AgentProvider).AgentConfig, true +} + +func (cfg *Config) loadRouteProviders(providers *config.Providers) E.Error { errs := E.NewBuilder("route provider errors") results := E.NewBuilder("loaded route providers") - lenLongestName := 0 + for _, agent := range providers.Agents { + if err := agent.Start(cfg.task); err != nil { + errs.Add(err.Subject(agent.String())) + continue + } + p := proxy.NewAgentProvider(agent) + if err := cfg.errIfExists(p); err != nil { + errs.Add(err.Subject(p.String())) + continue + } + cfg.storeProvider(p) + } for _, filename := range providers.Files { p, err := proxy.NewFileProvider(filename) + if err == nil { + err = cfg.errIfExists(p) + } if err != nil { errs.Add(E.PrependSubject(filename, err)) continue } - cfg.providers.Store(p.String(), p) - if len(p.String()) > lenLongestName { - lenLongestName = len(p.String()) - } + cfg.storeProvider(p) } for name, dockerHost := range providers.Docker { - p, err := proxy.NewDockerProvider(name, dockerHost) - if err != nil { - errs.Add(E.PrependSubject(name, err)) + p := proxy.NewDockerProvider(name, dockerHost) + if err := cfg.errIfExists(p); err != nil { + errs.Add(err.Subject(p.String())) continue } - cfg.providers.Store(p.String(), p) - if len(p.String()) > lenLongestName { - lenLongestName = len(p.String()) - } - } - for _, agent := range providers.Agents { - cfg.providers.Store(agent.Name(), proxy.NewAgentProvider(&agent)) + cfg.storeProvider(p) } if cfg.providers.Size() == 0 { return nil } + + lenLongestName := 0 + cfg.providers.RangeAll(func(k string, _ *proxy.Provider) { + if len(k) > lenLongestName { + lenLongestName = len(k) + } + }) cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) { if err := p.LoadRoutes(); err != nil { errs.Add(err.Subject(p.String())) diff --git a/internal/config/types/config.go b/internal/config/types/config.go index b75d421..b836b5f 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -3,6 +3,7 @@ package types import ( "context" "regexp" + "sync" "github.com/go-playground/validator/v10" "github.com/yusing/go-proxy/agent/pkg/agent" @@ -25,8 +26,8 @@ type ( } Providers struct { Files []string `json:"include" yaml:"include,omitempty" validate:"dive,filepath"` - Docker map[string]string `json:"docker" yaml:"docker,omitempty" validate:"dive,unix_addr|url"` - Agents []agent.AgentConfig `json:"agents" yaml:"agents,omitempty"` + Docker map[string]string `json:"docker" yaml:"docker,omitempty" validate:"non_empty_docker_keys,dive,unix_addr|url"` + Agents []*agent.AgentConfig `json:"agents" yaml:"agents,omitempty"` Notification []notif.NotificationConfig `json:"notification" yaml:"notification,omitempty"` } Entrypoint struct { @@ -40,9 +41,15 @@ type ( Statistics() map[string]any RouteProviderList() []string Context() context.Context + GetAgent(agentDockerHost string) (*agent.AgentConfig, bool) } ) +var ( + instance ConfigInstance + instanceMu sync.RWMutex +) + func DefaultConfig() *Config { return &Config{ TimeoutShutdown: 3, @@ -52,6 +59,24 @@ func DefaultConfig() *Config { } } +func GetInstance() ConfigInstance { + instanceMu.RLock() + defer instanceMu.RUnlock() + return instance +} + +func SetInstance(cfg ConfigInstance) { + instanceMu.Lock() + defer instanceMu.Unlock() + instance = cfg +} + +func HasInstance() bool { + instanceMu.RLock() + defer instanceMu.RUnlock() + return instance != nil +} + func Validate(data []byte) E.Error { var model Config return utils.DeserializeYAML(data, &model) @@ -70,4 +95,13 @@ func init() { } return true }) + utils.MustRegisterValidation("non_empty_docker_keys", func(fl validator.FieldLevel) bool { + m := fl.Field().Interface().(map[string]string) + for k := range m { + if k == "" { + return false + } + } + return true + }) } diff --git a/internal/docker/client.go b/internal/docker/client.go index 4bb2a6b..6ff7254 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "runtime/debug" "sync" "github.com/docker/cli/cli/connhelper" @@ -11,6 +12,7 @@ import ( "github.com/rs/zerolog" "github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/internal/common" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" @@ -84,9 +86,9 @@ func ConnectClient(host string) (*SharedClient, error) { var opt []client.Opt if agent.IsDockerHostAgent(host) { - cfg, ok := agent.GetAgentFromDockerHost(host) + cfg, ok := config.GetInstance().GetAgent(host) if !ok { - return nil, fmt.Errorf("agent not found for host: %s", host) + return nil, fmt.Errorf("agent %q not found\n%s", host, debug.Stack()) } opt = []client.Opt{ client.WithHost(agent.DockerHost), diff --git a/internal/docker/container.go b/internal/docker/container.go index 11b56d7..435048c 100644 --- a/internal/docker/container.go +++ b/internal/docker/container.go @@ -7,6 +7,7 @@ import ( "github.com/docker/docker/api/types" "github.com/yusing/go-proxy/agent/pkg/agent" + config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/logging" U "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -82,7 +83,11 @@ func FromDocker(c *types.Container, dockerHost string) (res *Container) { } if agent.IsDockerHostAgent(dockerHost) { - res.Agent, _ = agent.GetAgentFromDockerHost(dockerHost) + var ok bool + res.Agent, ok = config.GetInstance().GetAgent(dockerHost) + if !ok { + logging.Error().Msgf("agent %q not found", dockerHost) + } } res.setPrivateHostname(helper) diff --git a/internal/route/provider/agent.go b/internal/route/provider/agent.go index be2736f..85acec9 100644 --- a/internal/route/provider/agent.go +++ b/internal/route/provider/agent.go @@ -14,7 +14,7 @@ type AgentProvider struct { } func (p *AgentProvider) ShortName() string { - return p.Name() + return p.AgentConfig.Name() } func (p *AgentProvider) NewWatcher() watcher.Watcher { diff --git a/internal/route/provider/docker_test.go b/internal/route/provider/docker_test.go index bff489e..cac066e 100644 --- a/internal/route/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -39,8 +39,7 @@ func makeRoutes(cont *types.Container, dockerHostIP ...string) route.Routes { } func TestExplicitOnly(t *testing.T) { - p, err := NewDockerProvider("a!", "") - ExpectNoError(t, err) + p := NewDockerProvider("a!", "") ExpectTrue(t, p.IsExplicitOnly()) } diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 8cf4c03..78cb4a4 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -59,15 +59,11 @@ func NewFileProvider(filename string) (p *Provider, err error) { return } -func NewDockerProvider(name string, dockerHost string) (p *Provider, err error) { - if name == "" { - return nil, ErrEmptyProviderName - } - - p = newProvider(types.ProviderTypeDocker) +func NewDockerProvider(name string, dockerHost string) *Provider { + p := newProvider(types.ProviderTypeDocker) p.ProviderImpl = DockerProviderImpl(name, dockerHost) p.watcher = p.NewWatcher() - return + return p } func NewAgentProvider(cfg *agent.AgentConfig) *Provider { @@ -127,10 +123,6 @@ func (p *Provider) Start(parent task.Parent) E.Error { if err := errs.Error(); err != nil { return err.Subject(p.String()) } - - if p.t == types.ProviderTypeAgent { - t.OnCancel("remove agent", p.ProviderImpl.(*AgentProvider).AgentConfig.Remove) - } return nil }