refactor(provider): improve route handling

This commit is contained in:
yusing 2025-06-04 23:15:56 +08:00
parent 45e34d691a
commit b670cdbd49
3 changed files with 110 additions and 42 deletions

View file

@ -3,7 +3,7 @@ package provider
import ( import (
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/route"
"github.com/yusing/go-proxy/internal/route/provider/types" provider "github.com/yusing/go-proxy/internal/route/provider/types"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher"
eventsPkg "github.com/yusing/go-proxy/internal/watcher/events" eventsPkg "github.com/yusing/go-proxy/internal/watcher/events"
@ -23,7 +23,7 @@ func (p *Provider) newEventHandler() *EventHandler {
} }
func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) { func (handler *EventHandler) Handle(parent task.Parent, events []watcher.Event) {
oldRoutes := handler.provider.routes oldRoutes := handler.provider.lockCloneRoutes()
isForceReload := false isForceReload := false
for _, event := range events { for _, event := range events {
@ -68,10 +68,10 @@ func (handler *EventHandler) matchAny(events []watcher.Event, route *route.Route
func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool { func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool {
switch handler.provider.GetType() { switch handler.provider.GetType() {
case types.ProviderTypeDocker, types.ProviderTypeAgent: case provider.ProviderTypeDocker, provider.ProviderTypeAgent:
return route.Container.ContainerID == event.ActorID || return route.Container.ContainerID == event.ActorID ||
route.Container.ContainerName == event.ActorName route.Container.ContainerName == event.ActorName
case types.ProviderTypeFile: case provider.ProviderTypeFile:
return true return true
} }
// should never happen // should never happen
@ -86,12 +86,11 @@ func (handler *EventHandler) Add(parent task.Parent, route *route.Route) {
} }
func (handler *EventHandler) Remove(route *route.Route) { func (handler *EventHandler) Remove(route *route.Route) {
route.Finish("route removed") route.FinishAndWait("route removed")
delete(handler.provider.routes, route.Alias)
} }
func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) { func (handler *EventHandler) Update(parent task.Parent, oldRoute *route.Route, newRoute *route.Route) {
oldRoute.Finish("route update") oldRoute.FinishAndWait("route update")
err := handler.provider.startRoute(parent, newRoute) err := handler.provider.startRoute(parent, newRoute)
if err != nil { if err != nil {
handler.errs.Add(err.Subject("update")) handler.errs.Add(err.Subject("update"))

View file

@ -3,8 +3,9 @@ package provider
import ( import (
"errors" "errors"
"fmt" "fmt"
"maps"
"path" "path"
"slices" "sync"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -23,6 +24,8 @@ type (
ProviderImpl ProviderImpl
t provider.Type t provider.Type
routes route.Routes
routesMu sync.RWMutex
watcher W.Watcher watcher W.Watcher
} }
@ -42,6 +45,7 @@ const (
var ErrEmptyProviderName = errors.New("empty provider name") var ErrEmptyProviderName = errors.New("empty provider name")
var _ routes.Provider = (*Provider)(nil)
func newProvider(t provider.Type) *Provider { func newProvider(t provider.Type) *Provider {
return &Provider{t: t} return &Provider{t: t}
@ -88,41 +92,36 @@ func (p *Provider) MarshalText() ([]byte, error) {
return []byte(p.String()), nil return []byte(p.String()), nil
} }
func (p *Provider) startRoute(parent task.Parent, r *route.Route) gperr.Error {
err := r.Start(parent)
if err != nil {
delete(p.routes, r.Alias)
routes.All.Del(r)
return err.Subject(r.Alias)
}
if conflict, added := routes.All.AddIfNotExists(r); !added {
delete(p.routes, r.Alias)
return gperr.Errorf("route %s already exists: from %s and %s", r.Alias, r.ProviderName(), conflict.ProviderName())
} else {
r.Task().OnCancel("remove_routes_from_all", func() {
routes.All.Del(r)
})
}
return nil
}
// Start implements task.TaskStarter. // Start implements task.TaskStarter.
func (p *Provider) Start(parent task.Parent) gperr.Error { func (p *Provider) Start(parent task.Parent) gperr.Error {
errs := gperr.NewBuilder("routes error")
errs.EnableConcurrency()
t := parent.Subtask("provider."+p.String(), false) t := parent.Subtask("provider."+p.String(), false)
routesTask := t.Subtask("routes", false) // no need to lock here because we are not modifying the routes map.
errs := gperr.NewBuilder("routes error") routeSlice := make([]*route.Route, 0, len(p.routes))
for _, r := range p.routes { for _, r := range p.routes {
errs.Add(p.startRoute(routesTask, r)) routeSlice = append(routeSlice, r)
} }
var wg sync.WaitGroup
for _, r := range routeSlice {
wg.Add(1)
go func(r *route.Route) {
defer wg.Done()
errs.Add(p.startRoute(t, r))
}(r)
}
wg.Wait()
eventQueue := events.NewEventQueue( eventQueue := events.NewEventQueue(
t.Subtask("event_queue", false), t.Subtask("event_queue", false),
providerEventFlushInterval, providerEventFlushInterval,
func(events []events.Event) { func(events []events.Event) {
handler := p.newEventHandler() handler := p.newEventHandler()
// routes' lifetime should follow the provider's lifetime // routes' lifetime should follow the provider's lifetime
handler.Handle(routesTask, events) handler.Handle(t, events)
handler.Log() handler.Log()
}, },
func(err gperr.Error) { func(err gperr.Error) {
@ -137,17 +136,52 @@ func (p *Provider) Start(parent task.Parent) gperr.Error {
return nil return nil
} }
func (p *Provider) IterRoutes(yield func(string, *route.Route) bool) { func (p *Provider) LoadRoutes() (err gperr.Error) {
for alias, r := range p.routes { p.routes, err = p.loadRoutes()
if !yield(alias, r) { return
}
func (p *Provider) NumRoutes() int {
return len(p.routes)
}
func (p *Provider) IterRoutes(yield func(string, routes.Route) bool) {
routes := p.lockCloneRoutes()
for alias, r := range routes {
if !yield(alias, r.Impl()) {
break break
} }
} }
} }
func (p *Provider) GetRoute(alias string) (r *route.Route, ok bool) { func (p *Provider) FindService(project, service string) (routes.Route, bool) {
r, ok = p.routes[alias] switch p.GetType() {
return case provider.ProviderTypeDocker, provider.ProviderTypeAgent:
default:
return nil, false
}
if project == "" || service == "" {
return nil, false
}
routes := p.lockCloneRoutes()
for _, r := range routes {
cont := r.ContainerInfo()
if cont.DockerComposeProject() != project {
continue
}
if cont.DockerComposeService() == service {
return r.Impl(), true
}
}
return nil, false
}
func (p *Provider) GetRoute(alias string) (routes.Route, bool) {
r, ok := p.lockGetRoute(alias)
if !ok {
return nil, false
}
return r.Impl(), true
} }
func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) { func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) {
@ -161,7 +195,7 @@ func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) {
// set alias and provider, then validate // set alias and provider, then validate
for alias, r := range routes { for alias, r := range routes {
r.Alias = alias r.Alias = alias
r.Provider = p.ShortName() r.SetProvider(p)
if err := r.Validate(); err != nil { if err := r.Validate(); err != nil {
errs.Add(err.Subject(alias)) errs.Add(err.Subject(alias))
delete(routes, alias) delete(routes, alias)
@ -172,11 +206,40 @@ func (p *Provider) loadRoutes() (routes route.Routes, err gperr.Error) {
return routes, errs.Error() return routes, errs.Error()
} }
func (p *Provider) LoadRoutes() (err gperr.Error) { func (p *Provider) startRoute(parent task.Parent, r *route.Route) gperr.Error {
p.routes, err = p.loadRoutes() err := r.Start(parent)
return if err != nil {
p.lockDeleteRoute(r.Alias)
return err.Subject(r.Alias)
}
p.lockAddRoute(r)
r.Task().OnCancel("remove_route_from_provider", func() {
p.lockDeleteRoute(r.Alias)
})
return nil
} }
func (p *Provider) NumRoutes() int { func (p *Provider) lockAddRoute(r *route.Route) {
return len(p.routes) p.routesMu.Lock()
defer p.routesMu.Unlock()
p.routes[r.Alias] = r
}
func (p *Provider) lockDeleteRoute(alias string) {
p.routesMu.Lock()
defer p.routesMu.Unlock()
delete(p.routes, alias)
}
func (p *Provider) lockGetRoute(alias string) (*route.Route, bool) {
p.routesMu.RLock()
defer p.routesMu.RUnlock()
r, ok := p.routes[alias]
return r, ok
}
func (p *Provider) lockCloneRoutes() route.Routes {
p.routesMu.RLock()
defer p.routesMu.RUnlock()
return maps.Clone(p.routes)
} }

View file

@ -57,4 +57,10 @@ type (
Route Route
net.Stream net.Stream
} }
Provider interface {
GetRoute(alias string) (r Route, ok bool)
IterRoutes(yield func(alias string, r Route) bool)
FindService(project, service string) (r Route, ok bool)
ShortName() string
}
) )