diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 1bf2998..7453678 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -3,7 +3,7 @@ package provider import ( "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/route" - "github.com/yusing/go-proxy/internal/route/provider/types" + provider "github.com/yusing/go-proxy/internal/route/provider/types" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher" eventsPkg "github.com/yusing/go-proxy/internal/watcher/events" @@ -23,7 +23,7 @@ func (p *Provider) newEventHandler() *EventHandler { } func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) { - oldRoutes := handler.provider.routes + oldRoutes := handler.provider.lockCloneRoutes() isForceReload := false for _, event := range events { @@ -68,10 +68,10 @@ func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { switch handler.provider.GetType() { - case types.ProviderTypeDocker, types.ProviderTypeAgent: + case provider.ProviderTypeDocker, provider.ProviderTypeAgent: return route.Container.ContainerID == event.ActorID || route.Container.ContainerName == event.ActorName - case types.ProviderTypeFile: + case provider.ProviderTypeFile: return true } // should never happen @@ -86,12 +86,11 @@ func (handler *EventHandler) Add(parent task.Parent, route *route.Route) { } func (handler *EventHandler) Remove(route *route.Route) { - route.Finish("route removed") - delete(handler.provider.routes, route.Alias) + route.FinishAndWait("route removed") } func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) { - oldRoute.Finish("route update") + oldRoute.FinishAndWait("route update") err := handler.provider.startRoute(parent, newRoute) if err != nil { handler.errs.Add(err.Subject("update")) diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index f1f65cc..af0f71e 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -3,8 +3,9 @@ package provider import ( "errors" "fmt" + "maps" "path" - "slices" + "sync" "time" "github.com/rs/zerolog" @@ -23,6 +24,8 @@ type ( ProviderImpl t provider.Type + routes route.Routes + routesMu sync.RWMutex watcher W.Watcher } @@ -42,6 +45,7 @@ const ( var ErrEmptyProviderName = errors.New("empty provider name") +var _ routes.Provider = (*Provider)(nil) func newProvider(t provider.Type) *Provider { return &Provider{t: t} @@ -88,41 +92,36 @@ func (p *Provider) MarshalText() ([]byte, error) { return []byte(p.String()), nil } -func (p *Provider) startRoute(parent task.Parent, r *route.Route) gperr.Error { - err := r.Start(parent) - if err != nil { - delete(p.routes, r.Alias) - routes.All.Del(r) - return err.Subject(r.Alias) - } - if conflict, added := routes.All.AddIfNotExists(r); !added { - delete(p.routes, r.Alias) - return gperr.Errorf("route %s already exists: from %s and %s", r.Alias, r.ProviderName(), conflict.ProviderName()) - } else { - r.Task().OnCancel("remove_routes_from_all", func() { - routes.All.Del(r) - }) - } - return nil -} - // Start implements task.TaskStarter. func (p *Provider) Start(parent task.Parent) gperr.Error { + errs := gperr.NewBuilder("routes error") + errs.EnableConcurrency() + t := parent.Subtask("provider."+p.String(), false) - routesTask := t.Subtask("routes", false) - errs := gperr.NewBuilder("routes error") + // no need to lock here because we are not modifying the routes map. + routeSlice := make([]*route.Route, 0, len(p.routes)) for _, r := range p.routes { - errs.Add(p.startRoute(routesTask, r)) + routeSlice = append(routeSlice, r) } + var wg sync.WaitGroup + for _, r := range routeSlice { + wg.Add(1) + go func(r *route.Route) { + defer wg.Done() + errs.Add(p.startRoute(t, r)) + }(r) + } + wg.Wait() + eventQueue := events.NewEventQueue( t.Subtask("event_queue", false), providerEventFlushInterval, func(events []events.Event) { handler := p.newEventHandler() // routes' lifetime should follow the provider's lifetime - handler.Handle(routesTask, events) + handler.Handle(t, events) handler.Log() }, func(err gperr.Error) { @@ -137,17 +136,52 @@ func (p *Provider) Start(parent task.Parent) gperr.Error { return nil } -func (p *Provider) IterRoutes(yield func(string, *route.Route) bool) { - for alias, r := range p.routes { - if !yield(alias, r) { +func (p *Provider) LoadRoutes() (err gperr.Error) { + p.routes, err = p.loadRoutes() + return +} + +func (p *Provider) NumRoutes() int { + return len(p.routes) +} + +func (p *Provider) IterRoutes(yield func(string, routes.Route) bool) { + routes := p.lockCloneRoutes() + for alias, r := range routes { + if !yield(alias, r.Impl()) { break } } } -func (p *Provider) GetRoute(alias string) (r *route.Route, ok bool) { - r, ok = p.routes[alias] - return +func (p *Provider) FindService(project, service string) (routes.Route, bool) { + switch p.GetType() { + case provider.ProviderTypeDocker, provider.ProviderTypeAgent: + default: + return nil, false + } + if project == "" || service == "" { + return nil, false + } + routes := p.lockCloneRoutes() + for _, r := range routes { + cont := r.ContainerInfo() + if cont.DockerComposeProject() != project { + continue + } + if cont.DockerComposeService() == service { + return r.Impl(), true + } + } + return nil, false +} + +func (p *Provider) GetRoute(alias string) (routes.Route, bool) { + r, ok := p.lockGetRoute(alias) + if !ok { + return nil, false + } + return r.Impl(), true } func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) { @@ -161,7 +195,7 @@ func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) { // set alias and provider, then validate for alias, r := range routes { r.Alias = alias - r.Provider = p.ShortName() + r.SetProvider(p) if err := r.Validate(); err != nil { errs.Add(err.Subject(alias)) delete(routes, alias) @@ -172,11 +206,40 @@ func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) { return routes, errs.Error() } -func (p *Provider) LoadRoutes() (err gperr.Error) { - p.routes, err = p.loadRoutes() - return +func (p *Provider) startRoute(parent task.Parent, r *route.Route) gperr.Error { + err := r.Start(parent) + if err != nil { + p.lockDeleteRoute(r.Alias) + return err.Subject(r.Alias) + } + p.lockAddRoute(r) + r.Task().OnCancel("remove_route_from_provider", func() { + p.lockDeleteRoute(r.Alias) + }) + return nil } -func (p *Provider) NumRoutes() int { - return len(p.routes) +func (p *Provider) lockAddRoute(r *route.Route) { + p.routesMu.Lock() + defer p.routesMu.Unlock() + p.routes[r.Alias] = r +} + +func (p *Provider) lockDeleteRoute(alias string) { + p.routesMu.Lock() + defer p.routesMu.Unlock() + delete(p.routes, alias) +} + +func (p *Provider) lockGetRoute(alias string) (*route.Route, bool) { + p.routesMu.RLock() + defer p.routesMu.RUnlock() + r, ok := p.routes[alias] + return r, ok +} + +func (p *Provider) lockCloneRoutes() route.Routes { + p.routesMu.RLock() + defer p.routesMu.RUnlock() + return maps.Clone(p.routes) } diff --git a/internal/route/routes/route.go b/internal/route/routes/route.go index 6f10455..83a1f30 100644 --- a/internal/route/routes/route.go +++ b/internal/route/routes/route.go @@ -57,4 +57,10 @@ type ( Route net.Stream } + Provider interface { + GetRoute(alias string) (r Route, ok bool) + IterRoutes(yield func(alias string, r Route) bool) + FindService(project, service string) (r Route, ok bool) + ShortName() string + } )