diff --git a/docs/docker.md b/docs/docker.md index d49f473..07334fe 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -250,7 +250,6 @@ services: ports: - 25565 labels: - - proxy.mc.scheme=tcp - proxy.mc.port=20001:25565 environment: - EULA=TRUE diff --git a/docs/docker_socket_proxy.md b/docs/docker_socket_proxy.md index 2c0e4db..c615371 100644 --- a/docs/docker_socket_proxy.md +++ b/docs/docker_socket_proxy.md @@ -4,56 +4,37 @@ For docker client on other machine, set this up, then add `name: tcp://:2375:2375 ``` ```yml # config.yml on go-proxy machine autocert: - ... # your config + ... # your config providers: - include: - ... - docker: - ... - server1: tcp://:2375 + include: + ... + docker: + ... + server1: tcp://:2375 ``` diff --git a/src/autocert/provider.go b/src/autocert/provider.go index e78116a..f7d2b3a 100644 --- a/src/autocert/provider.go +++ b/src/autocert/provider.go @@ -59,7 +59,7 @@ func (p *Provider) ObtainCert() (res E.NestedError) { defer b.To(&res) if p.cfg.Provider == ProviderLocal { - b.Addf("provider is set to %q", ProviderLocal) + b.Addf("provider is set to %q", ProviderLocal).WithSeverity(E.SeverityWarning) return } diff --git a/src/autocert/setup.go b/src/autocert/setup.go new file mode 100644 index 0000000..72fe3f4 --- /dev/null +++ b/src/autocert/setup.go @@ -0,0 +1,29 @@ +package autocert + +import ( + "context" + "os" + + E "github.com/yusing/go-proxy/error" +) + +func (p *Provider) Setup(ctx context.Context) (err E.NestedError) { + if err = p.LoadCert(); err != nil { + if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist + return err + } + logger.Debug("obtaining cert due to error loading cert") + if err = p.ObtainCert(); err != nil { + return err.Warn() + } + } + + go p.ScheduleRenewal(ctx) + + for _, expiry := range p.GetExpiries() { + logger.Infof("certificate expire on %s", expiry) + break + } + + return nil +} diff --git a/src/common/args.go b/src/common/args.go index 806d210..c8d0c2c 100644 --- a/src/common/args.go +++ b/src/common/args.go @@ -12,12 +12,13 @@ type Args struct { } const ( - CommandStart = "" - CommandValidate = "validate" - CommandListConfigs = "ls-config" - CommandListRoutes = "ls-routes" - CommandReload = "reload" - CommandDebugListEntries = "debug-ls-entries" + CommandStart = "" + CommandValidate = "validate" + CommandListConfigs = "ls-config" + CommandListRoutes = "ls-routes" + CommandReload = "reload" + CommandDebugListEntries = "debug-ls-entries" + CommandDebugListProviders = "debug-ls-providers" ) var ValidCommands = []string{ @@ -27,6 +28,7 @@ var ValidCommands = []string{ CommandListRoutes, CommandReload, CommandDebugListEntries, + CommandDebugListProviders, } func GetArgs() Args { diff --git a/src/common/constants.go b/src/common/constants.go index d871988..5dc0fbe 100644 --- a/src/common/constants.go +++ b/src/common/constants.go @@ -30,9 +30,9 @@ const ( ) const ( - SchemaBasePath = "schema/" - ConfigSchemaPath = SchemaBasePath + "config.schema.json" - ProvidersSchemaPath = SchemaBasePath + "providers.schema.json" + SchemaBasePath = "schema/" + ConfigSchemaPath = SchemaBasePath + "config.schema.json" + FileProviderSchemaPath = SchemaBasePath + "providers.schema.json" ) const DockerHostFromEnv = "$DOCKER_HOST" diff --git a/src/config/config.go b/src/config/config.go index 1fdbcfc..e4864ae 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -14,6 +14,7 @@ import ( U "github.com/yusing/go-proxy/utils" F "github.com/yusing/go-proxy/utils/functional" W "github.com/yusing/go-proxy/watcher" + "github.com/yusing/go-proxy/watcher/events" "gopkg.in/yaml.v3" ) @@ -94,7 +95,7 @@ func (cfg *Config) WatchChanges() { case <-cfg.watcherCtx.Done(): return case event := <-eventCh: - if event.Action.IsDelete() { + if event.Action == events.ActionFileDeleted { cfg.stopProviders() } else { cfg.reloadReq <- struct{}{} @@ -107,71 +108,6 @@ func (cfg *Config) WatchChanges() { }() } -func (cfg *Config) FindRoute(alias string) R.Route { - return F.MapFind(cfg.proxyProviders, - func(p *PR.Provider) (R.Route, bool) { - if route, ok := p.GetRoute(alias); ok { - return route, true - } - return nil, false - }, - ) -} - -func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { - routes := make(map[string]U.SerializedObject) - cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { - obj, err := U.Serialize(r) - if err.HasError() { - cfg.l.Error(err) - return - } - obj["provider"] = p.GetName() - obj["type"] = string(r.Type()) - routes[alias] = obj - }) - return routes -} - -func (cfg *Config) Statistics() map[string]any { - nTotalStreams := 0 - nTotalRPs := 0 - providerStats := make(map[string]any) - - cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { - s, ok := providerStats[p.GetName()] - if !ok { - s = make(map[string]int) - } - - stats := s.(map[string]int) - switch r.Type() { - case R.RouteTypeStream: - stats["num_streams"]++ - nTotalStreams++ - case R.RouteTypeReverseProxy: - stats["num_reverse_proxies"]++ - nTotalRPs++ - default: - panic("bug: should not reach here") - } - }) - - return map[string]any{ - "num_total_streams": nTotalStreams, - "num_total_reverse_proxies": nTotalRPs, - "providers": providerStats, - } -} - -func (cfg *Config) DumpEntries() map[string]*M.RawEntry { - entries := make(map[string]*M.RawEntry) - cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { - entries[alias] = r.Entry() - }) - return entries -} - func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) { cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) { p.RangeRoutes(func(a string, r R.Route) { @@ -259,7 +195,7 @@ func (cfg *Config) loadProviders(providers *M.ProxyProviders) (res E.NestedError } func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) { - errors := E.NewBuilder("cannot %s these providers", action) + errors := E.NewBuilder("errors in %s these providers", action) cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { if err := do(p); err.HasError() { diff --git a/src/config/query.go b/src/config/query.go new file mode 100644 index 0000000..47e15e5 --- /dev/null +++ b/src/config/query.go @@ -0,0 +1,82 @@ +package config + +import ( + M "github.com/yusing/go-proxy/models" + PR "github.com/yusing/go-proxy/proxy/provider" + R "github.com/yusing/go-proxy/route" + U "github.com/yusing/go-proxy/utils" + F "github.com/yusing/go-proxy/utils/functional" +) + +func (cfg *Config) DumpEntries() map[string]*M.RawEntry { + entries := make(map[string]*M.RawEntry) + cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { + entries[alias] = r.Entry() + }) + return entries +} + +func (cfg *Config) DumpProviders() map[string]*PR.Provider { + entries := make(map[string]*PR.Provider) + cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { + entries[name] = p + }) + return entries +} + +func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { + routes := make(map[string]U.SerializedObject) + cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { + obj, err := U.Serialize(r) + if err.HasError() { + cfg.l.Error(err) + return + } + obj["provider"] = p.GetName() + obj["type"] = string(r.Type()) + routes[alias] = obj + }) + return routes +} + +func (cfg *Config) Statistics() map[string]any { + nTotalStreams := 0 + nTotalRPs := 0 + providerStats := make(map[string]any) + + cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { + s, ok := providerStats[p.GetName()] + if !ok { + s = make(map[string]int) + } + + stats := s.(map[string]int) + switch r.Type() { + case R.RouteTypeStream: + stats["num_streams"]++ + nTotalStreams++ + case R.RouteTypeReverseProxy: + stats["num_reverse_proxies"]++ + nTotalRPs++ + default: + panic("bug: should not reach here") + } + }) + + return map[string]any{ + "num_total_streams": nTotalStreams, + "num_total_reverse_proxies": nTotalRPs, + "providers": providerStats, + } +} + +func (cfg *Config) FindRoute(alias string) R.Route { + return F.MapFind(cfg.proxyProviders, + func(p *PR.Provider) (R.Route, bool) { + if route, ok := p.GetRoute(alias); ok { + return route, true + } + return nil, false + }, + ) +} diff --git a/src/docker/client.go b/src/docker/client.go index 70b4798..89c16f5 100644 --- a/src/docker/client.go +++ b/src/docker/client.go @@ -10,6 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/common" E "github.com/yusing/go-proxy/error" + F "github.com/yusing/go-proxy/utils/functional" ) type Client struct { @@ -48,9 +49,7 @@ func (c *Client) Close() error { return nil } - clientMapMu.Lock() - defer clientMapMu.Unlock() - delete(clientMap, c.key) + clientMap.Delete(c.key) client := c.Client c.Client = nil @@ -78,7 +77,7 @@ func ConnectClient(host string) (Client, E.NestedError) { defer clientMapMu.Unlock() // check if client exists - if client, ok := clientMap[host]; ok { + if client, ok := clientMap.Load(host); ok { client.refCount.Add(1) return client, nil } @@ -129,23 +128,22 @@ func ConnectClient(host string) (Client, E.NestedError) { c.refCount.Add(1) c.l.Debugf("client connected") - clientMap[host] = c - return clientMap[host], nil + clientMap.Store(host, c) + return c, nil } func CloseAllClients() { - clientMapMu.Lock() - defer clientMapMu.Unlock() - for _, client := range clientMap { - client.Close() - } - clientMap = make(map[string]Client) + clientMap.RangeAll(func(_ string, c Client) { + c.Client.Close() + }) + clientMap.Clear() logger.Debug("closed all clients") } var ( - clientMap map[string]Client = make(map[string]Client) - clientMapMu sync.Mutex + clientMap F.Map[string, Client] = F.NewMapOf[string, Client]() + clientMapMu sync.Mutex + clientOptEnvHost = []client.Opt{ client.WithHostFromEnv(), client.WithAPIVersionNegotiation(), diff --git a/src/docker/idlewatcher/html/loading_page.html b/src/docker/idlewatcher/html/loading_page.html new file mode 100644 index 0000000..d8eaf00 --- /dev/null +++ b/src/docker/idlewatcher/html/loading_page.html @@ -0,0 +1,87 @@ + + + + + + {{.Title}} + + + + +
+
{{.Message}}
+ + diff --git a/src/docker/idlewatcher/http.go b/src/docker/idlewatcher/http.go new file mode 100644 index 0000000..c220366 --- /dev/null +++ b/src/docker/idlewatcher/http.go @@ -0,0 +1,93 @@ +package idlewatcher + +import ( + "bytes" + _ "embed" + "fmt" + "io" + "net/http" + "strings" + "text/template" +) + +type templateData struct { + Title string + Message string + RequestHeaders http.Header + SpinnerClass string +} + +//go:embed html/loading_page.html +var loadingPage []byte +var loadingPageTmpl = func() *template.Template { + tmpl, err := template.New("loading").Parse(string(loadingPage)) + if err != nil { + panic(err) + } + return tmpl +}() + +const ( + htmlContentType = "text/html; charset=utf-8" + + errPrefix = "\u1000" + + headerGoProxyTargetURL = "X-GoProxy-Target" + headerContentType = "Content-Type" + + spinnerClassSpinner = "spinner" + spinnerClassErrorSign = "error" +) + +func (w *watcher) makeSuccResp(redirectURL string, resp *http.Response) (*http.Response, error) { + h := make(http.Header) + h.Set("Location", redirectURL) + h.Set("Content-Length", "0") + h.Set(headerContentType, htmlContentType) + return &http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: h, + Body: http.NoBody, + TLS: resp.TLS, + }, nil +} + +func (w *watcher) makeErrResp(errFmt string, args ...any) (*http.Response, error) { + return w.makeResp(errPrefix+errFmt, args...) +} + +func (w *watcher) makeResp(format string, args ...any) (*http.Response, error) { + msg := fmt.Sprintf(format, args...) + + data := new(templateData) + data.Title = w.ContainerName + data.Message = strings.ReplaceAll(msg, "\n", "
") + data.Message = strings.ReplaceAll(data.Message, " ", " ") + data.RequestHeaders = make(http.Header) + data.RequestHeaders.Add(headerGoProxyTargetURL, "window.location.href") + if strings.HasPrefix(data.Message, errPrefix) { + data.Message = strings.TrimLeft(data.Message, errPrefix) + data.SpinnerClass = spinnerClassErrorSign + } else { + data.SpinnerClass = spinnerClassSpinner + } + + buf := bytes.NewBuffer(make([]byte, 128)) // more than enough + err := loadingPageTmpl.Execute(buf, data) + if err != nil { // should never happen + panic(err) + } + return &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + headerContentType: {htmlContentType}, + "Cache-Control": { + "no-cache", + "no-store", + "must-revalidate", + }, + }, + Body: io.NopCloser(buf), + ContentLength: int64(buf.Len()), + }, nil +} diff --git a/src/docker/idlewatcher/round_trip.go b/src/docker/idlewatcher/round_trip.go index dc352d1..f4444fa 100644 --- a/src/docker/idlewatcher/round_trip.go +++ b/src/docker/idlewatcher/round_trip.go @@ -1,6 +1,10 @@ package idlewatcher -import "net/http" +import ( + "context" + "net/http" + "time" +) type ( roundTripper struct { @@ -12,3 +16,63 @@ type ( func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return rt.patched(req) } + +func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) { + // target site is ready, passthrough + if w.ready.Load() { + return origRoundTrip(req) + } + + // wake the container + w.wakeCh <- struct{}{} + + // initial request + targetUrl := req.Header.Get(headerGoProxyTargetURL) + if targetUrl == "" { + return w.makeResp( + "%s is starting... Please wait", + w.ContainerName, + ) + } + + w.l.Debug("serving event") + + // stream request + rtDone := make(chan *http.Response, 1) + ctx, cancel := context.WithTimeout(req.Context(), w.WakeTimeout) + defer cancel() + + // loop original round trip until success in a goroutine + go func() { + for { + select { + case <-ctx.Done(): + return + case <-w.ctx.Done(): + return + default: + resp, err := origRoundTrip(req) + if err == nil { + w.ready.Store(true) + rtDone <- resp + return + } + time.Sleep(time.Millisecond * 200) + } + } + }() + + for { + select { + case resp := <-rtDone: + return w.makeSuccResp(targetUrl, resp) + case <-ctx.Done(): + if ctx.Err() == context.DeadlineExceeded { + return w.makeErrResp("Timed out waiting for %s to fully wake", w.ContainerName) + } + return w.makeErrResp("idlewatcher has stopped\n%s", w.ctx.Err().Error()) + case <-w.ctx.Done(): + return w.makeErrResp("idlewatcher has stopped\n%s", w.ctx.Err().Error()) + } + } +} diff --git a/src/docker/idlewatcher/watcher.go b/src/docker/idlewatcher/watcher.go index a42a2e8..cdb158e 100644 --- a/src/docker/idlewatcher/watcher.go +++ b/src/docker/idlewatcher/watcher.go @@ -1,9 +1,7 @@ package idlewatcher import ( - "bytes" "context" - "io" "net/http" "sync" "sync/atomic" @@ -16,33 +14,45 @@ import ( P "github.com/yusing/go-proxy/proxy" PT "github.com/yusing/go-proxy/proxy/fields" W "github.com/yusing/go-proxy/watcher" - event "github.com/yusing/go-proxy/watcher/events" ) -type watcher struct { - *P.ReverseProxyEntry - - client D.Client - - refCount atomic.Int32 - - stopByMethod StopCallback - wakeCh chan struct{} - wakeDone chan E.NestedError - running atomic.Bool - - ctx context.Context - cancel context.CancelFunc - - l logrus.FieldLogger -} - type ( + watcher struct { + *P.ReverseProxyEntry + + client D.Client + + ready atomic.Bool // whether the site is ready to accept connection + stopByMethod StopCallback // send a docker command w.r.t. `stop_method` + + wakeCh chan struct{} + wakeDone chan E.NestedError + + ctx context.Context + cancel context.CancelFunc + refCount *sync.WaitGroup + + l logrus.FieldLogger + } + WakeDone <-chan error WakeFunc func() WakeDone StopCallback func() E.NestedError ) +var ( + mainLoopCtx context.Context + mainLoopCancel context.CancelFunc + mainLoopWg sync.WaitGroup + + watcherMap = make(map[string]*watcher) + watcherMapMu sync.Mutex + + newWatcherCh = make(chan *watcher) + + logger = logrus.WithField("module", "idle_watcher") +) + func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { failure := E.Failure("idle_watcher register") @@ -67,12 +77,12 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { w := &watcher{ ReverseProxyEntry: entry, client: client, + refCount: &sync.WaitGroup{}, wakeCh: make(chan struct{}, 1), wakeDone: make(chan E.NestedError, 1), l: logger.WithField("container", entry.ContainerName), } w.refCount.Add(1) - w.running.Store(entry.ContainerRunning) w.stopByMethod = w.getStopCallback() watcherMap[w.ContainerName] = w @@ -84,20 +94,9 @@ func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { return w, nil } -// If the container is not registered, this is no-op func Unregister(containerName string) { - watcherMapMu.Lock() - defer watcherMapMu.Unlock() - if w, ok := watcherMap[containerName]; ok { - if w.refCount.Add(-1) > 0 { - return - } - if w.cancel != nil { - w.cancel() - } - w.client.Close() - delete(watcherMap, containerName) + w.refCount.Add(-1) } } @@ -107,8 +106,6 @@ func Start() { mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background()) - defer mainLoopWg.Wait() - for { select { case <-mainLoopCtx.Done(): @@ -117,8 +114,11 @@ func Start() { w.l.Debug("registered") mainLoopWg.Add(1) go func() { - w.watch() - Unregister(w.ContainerName) + w.watchUntilCancel() + w.refCount.Wait() // wait for 0 ref count + + w.client.Close() + delete(watcherMap, w.ContainerName) w.l.Debug("unregistered") mainLoopWg.Done() }() @@ -137,31 +137,6 @@ func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper { }} } -func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) { - w.wakeCh <- struct{}{} - - if w.running.Load() { - return origRoundTrip(req) - } - timeout := time.After(w.WakeTimeout) - - for { - if w.running.Load() { - return origRoundTrip(req) - } - select { - case <-req.Context().Done(): - return nil, req.Context().Err() - case err := <-w.wakeDone: - if err != nil { - return nil, err.Error() - } - case <-timeout: - return getLoadingResponse(), nil - } - } -} - func (w *watcher) containerStop() error { return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{ Signal: string(w.StopSignal), @@ -205,7 +180,6 @@ func (w *watcher) wakeIfStopped() E.NestedError { case "paused": return E.From(w.containerUnpause()) case "running": - w.running.Store(true) return nil default: return E.Unexpected("container state", status) @@ -236,15 +210,12 @@ func (w *watcher) getStopCallback() StopCallback { } } -func (w *watcher) watch() { - watcherCtx, watcherCancel := context.WithCancel(context.Background()) - w.ctx = watcherCtx - w.cancel = watcherCancel - - dockerWatcher := W.NewDockerWatcherWithClient(w.client) - +func (w *watcher) watchUntilCancel() { defer close(w.wakeCh) + w.ctx, w.cancel = context.WithCancel(context.Background()) + + dockerWatcher := W.NewDockerWatcherWithClient(w.client) dockerEventCh, dockerEventErrCh := dockerWatcher.EventsWithOptions(w.ctx, W.DockerListOptions{ Filters: W.NewDockerFilter( W.DockerFilterContainer, @@ -265,7 +236,7 @@ func (w *watcher) watch() { select { case <-mainLoopCtx.Done(): w.cancel() - case <-watcherCtx.Done(): + case <-w.ctx.Done(): w.l.Debug("stopped") return case err := <-dockerEventErrCh: @@ -273,16 +244,18 @@ func (w *watcher) watch() { w.l.Error(E.FailWith("docker watcher", err)) } case e := <-dockerEventCh: - switch e.Action { - case event.ActionDockerStartUnpause: - w.running.Store(true) - w.l.Infof("%s %s", e.ActorName, e.Action) - case event.ActionDockerStopPause: - w.running.Store(false) - w.l.Infof("%s %s", e.ActorName, e.Action) + switch { + // create / start / unpause + case e.Action.IsContainerWake(): + ticker.Reset(w.IdleTimeout) + w.l.Info(e) + default: // stop / pause / kill + ticker.Stop() + w.ready.Store(false) + w.l.Info(e) } case <-ticker.C: - w.l.Debug("timeout") + w.l.Debug("idle timeout") ticker.Stop() if err := w.stopByMethod(); err != nil && err.IsNot(context.Canceled) { w.l.Error(E.FailWith("stop", err).Extraf("stop method: %s", w.StopMethod)) @@ -301,57 +274,3 @@ func (w *watcher) watch() { } } } - -func getLoadingResponse() *http.Response { - return &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Content-Type": {"text/html"}, - "Cache-Control": { - "no-cache", - "no-store", - "must-revalidate", - }, - }, - Body: io.NopCloser(bytes.NewReader((loadingPage))), - ContentLength: int64(len(loadingPage)), - } -} - -var ( - mainLoopCtx context.Context - mainLoopCancel context.CancelFunc - mainLoopWg sync.WaitGroup - - watcherMap = make(map[string]*watcher) - watcherMapMu sync.Mutex - - newWatcherCh = make(chan *watcher) - - logger = logrus.WithField("module", "idle_watcher") - - loadingPage = []byte(` - - - - - - Loading... - - - -

Container is starting... Please wait

- - -`[1:]) -) diff --git a/src/error/builder.go b/src/error/builder.go index cfd29be..38a1bf2 100644 --- a/src/error/builder.go +++ b/src/error/builder.go @@ -25,6 +25,7 @@ func NewBuilder(format string, args ...any) Builder { func (b Builder) Add(err NestedError) Builder { if err != nil { b.Lock() + // TODO: if err severity is higher than b.severity, update b.severity b.errors = append(b.errors, err) b.Unlock() } diff --git a/src/error/error.go b/src/error/error.go index cc3934a..32ad66d 100644 --- a/src/error/error.go +++ b/src/error/error.go @@ -18,8 +18,8 @@ type ( ) const ( - SeverityFatal Severity = iota - SeverityWarning + SeverityWarning Severity = iota + SeverityFatal ) func From(err error) NestedError { diff --git a/src/error/errors.go b/src/error/errors.go index 14b3413..a430a1d 100644 --- a/src/error/errors.go +++ b/src/error/errors.go @@ -20,7 +20,7 @@ func Failure(what string) NestedError { } func FailedWhy(what string, why string) NestedError { - return errorf("%s %w because %s", what, ErrFailure, why) + return Failure(what).With(why) } func FailWith(what string, err any) NestedError { diff --git a/src/main.go b/src/main.go index 7bcb06a..5384d49 100755 --- a/src/main.go +++ b/src/main.go @@ -8,6 +8,9 @@ import ( "net/http" "os" "os/signal" + "reflect" + "runtime" + "strings" "sync" "syscall" "time" @@ -28,6 +31,7 @@ import ( func main() { args := common.GetArgs() l := logrus.WithField("module", "main") + onShutdown := F.NewSlice[func()]() if common.IsDebug { logrus.SetLevel(logrus.DebugLevel) @@ -40,20 +44,18 @@ func main() { DisableSorting: true, DisableLevelTruncation: true, FullTimestamp: true, - ForceColors: true, TimestampFormat: "01-02 15:04:05", }) } if args.Command == common.CommandReload { if err := apiUtils.ReloadServer(); err.HasError() { - l.Fatal(err) + log.Fatal(err) } + log.Print("ok") return } - onShutdown := F.NewSlice[func()]() - // exit if only validate config if args.Command == common.CommandValidate { data, err := os.ReadFile(common.ConfigPath) @@ -72,19 +74,19 @@ func main() { log.Fatal(err) } - if args.Command == common.CommandListConfigs { + switch args.Command { + case common.CommandListConfigs: printJSON(cfg.Value()) return - } - - if args.Command == common.CommandListRoutes { + case common.CommandListRoutes: printJSON(cfg.RoutesByAlias()) return - } - - if args.Command == common.CommandDebugListEntries { + case common.CommandDebugListEntries: printJSON(cfg.DumpEntries()) return + case common.CommandDebugListProviders: + printJSON(cfg.DumpProviders()) + return } cfg.StartProxyProviders() @@ -106,25 +108,14 @@ func main() { autocert := cfg.GetAutoCertProvider() if autocert != nil { - if err = autocert.LoadCert(); err.HasError() { - if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist - l.Error(err) - } - l.Debug("obtaining cert due to error loading cert") - if err = autocert.ObtainCert(); err.HasError() { - l.Warn(err) - } - } - - if err.NoError() { - ctx, certRenewalCancel := context.WithCancel(context.Background()) - go autocert.ScheduleRenewal(ctx) - onShutdown.Add(certRenewalCancel) - } - - for _, expiry := range autocert.GetExpiries() { - l.Infof("certificate expire on %s", expiry) - break + ctx, cancel := context.WithCancel(context.Background()) + if err = autocert.Setup(ctx); err != nil && err.IsWarning() { + cancel() + l.Warn(err) + } else if err.IsFatal() { + l.Fatal(err) + } else { + onShutdown.Add(cancel) } } else { l.Info("autocert not configured") @@ -165,7 +156,9 @@ func main() { wg.Add(onShutdown.Size()) onShutdown.ForEach(func(f func()) { go func() { + l.Debugf("waiting for %s to complete...", funcName(f)) f() + l.Debugf("%s done", funcName(f)) wg.Done() }() }) @@ -180,9 +173,17 @@ func main() { logrus.Info("shutdown complete") case <-timeout: logrus.Info("timeout waiting for shutdown") + onShutdown.ForEach(func(f func()) { + l.Warnf("%s() is still running", funcName(f)) + }) } } +func funcName(f func()) string { + parts := strings.Split(runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name(), "/go-proxy/") + return parts[len(parts)-1] +} + func printJSON(obj any) { j, err := E.Check(json.Marshal(obj)) if err.HasError() { diff --git a/src/proxy/fields/stream_scheme.go b/src/proxy/fields/stream_scheme.go index 3287ab7..318ad2e 100644 --- a/src/proxy/fields/stream_scheme.go +++ b/src/proxy/fields/stream_scheme.go @@ -32,7 +32,7 @@ func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { } func (s StreamScheme) String() string { - return fmt.Sprintf("%s -> %s", s.ListeningScheme, s.ProxyScheme) + return fmt.Sprintf("%s:%s", s.ListeningScheme, s.ProxyScheme) } // IsCoherent checks if the ListeningScheme and ProxyScheme of the StreamScheme are equal. diff --git a/src/proxy/provider/docker_provider.go b/src/proxy/provider/docker_provider.go index dbfb2ec..45a04aa 100755 --- a/src/proxy/provider/docker_provider.go +++ b/src/proxy/provider/docker_provider.go @@ -1,6 +1,7 @@ package provider import ( + "fmt" "regexp" "strconv" "strings" @@ -26,6 +27,10 @@ func DockerProviderImpl(dockerHost string) (ProviderImpl, E.NestedError) { return &DockerProvider{dockerHost: dockerHost, hostname: hostname}, nil } +func (p *DockerProvider) String() string { + return fmt.Sprintf("docker:%s", p.dockerHost) +} + func (p *DockerProvider) NewWatcher() W.Watcher { return W.NewDockerWatcher(p.dockerHost) } @@ -145,7 +150,7 @@ func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.Ra entryPortSplit := strings.Split(entry.Port, ":") if len(entryPortSplit) == 2 && entryPortSplit[1] == containerPort { entryPortSplit[1] = publicPort - } else if entryPortSplit[0] == containerPort { + } else if len(entryPortSplit) == 1 && entryPortSplit[0] == containerPort { entryPortSplit[0] = publicPort } entry.Port = strings.Join(entryPortSplit, ":") diff --git a/src/proxy/provider/file_provider.go b/src/proxy/provider/file_provider.go index ebf6735..9216c3a 100644 --- a/src/proxy/provider/file_provider.go +++ b/src/proxy/provider/file_provider.go @@ -35,7 +35,11 @@ func FileProviderImpl(filename string) (ProviderImpl, E.NestedError) { } func Validate(data []byte) E.NestedError { - return U.ValidateYaml(U.GetSchema(common.ProvidersSchemaPath), data) + return U.ValidateYaml(U.GetSchema(common.FileProviderSchemaPath), data) +} + +func (p FileProvider) String() string { + return p.fileName } func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) { diff --git a/src/proxy/provider/provider.go b/src/proxy/provider/provider.go index 54ff119..3f4cafc 100644 --- a/src/proxy/provider/provider.go +++ b/src/proxy/provider/provider.go @@ -2,7 +2,6 @@ package provider import ( "context" - "fmt" "path" "github.com/sirupsen/logrus" @@ -13,7 +12,7 @@ import ( type ( Provider struct { - ProviderImpl + ProviderImpl `json:"-"` name string t ProviderType @@ -30,6 +29,7 @@ type ( // even returns error, routes must be non-nil LoadRoutesImpl() (R.Routes, E.NestedError) OnEvent(event W.Event, routes R.Routes) EventResult + String() string } ProviderType string EventResult struct { @@ -83,8 +83,9 @@ func (p *Provider) GetType() ProviderType { return p.t } -func (p *Provider) String() string { - return fmt.Sprintf("%s-%s", p.t, p.name) +// to work with json marshaller +func (p *Provider) MarshalText() ([]byte, error) { + return []byte(p.String()), nil } func (p *Provider) StartAllRoutes() (res E.NestedError) { @@ -92,7 +93,6 @@ func (p *Provider) StartAllRoutes() (res E.NestedError) { defer errors.To(&res) // start watcher no matter load success or not - p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background()) go p.watchEvents() nStarted := 0 @@ -153,6 +153,7 @@ func (p *Provider) LoadRoutes() E.NestedError { } func (p *Provider) watchEvents() { + p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background()) events, errs := p.watcher.Events(p.watcherCtx) l := p.l.WithField("module", "watcher") @@ -160,21 +161,15 @@ func (p *Provider) watchEvents() { select { case <-p.watcherCtx.Done(): return - case event, ok := <-events: - if !ok { // channel closed - return - } + case event := <-events: res := p.OnEvent(event, p.routes) l.Infof("%s event %q", event.Type, event) l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved) if res.err.HasError() { l.Error(res.err) } - case err, ok := <-errs: - if !ok { - return - } - if err.Is(context.Canceled) { + case err := <-errs: + if err == nil || err.Is(context.Canceled) { continue } l.Errorf("watcher error: %s", err) diff --git a/src/proxy/reverse_proxy_mod.go b/src/proxy/reverse_proxy_mod.go index 811e220..518d732 100644 --- a/src/proxy/reverse_proxy_mod.go +++ b/src/proxy/reverse_proxy_mod.go @@ -232,7 +232,7 @@ func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *Revers } return &ReverseProxy{Rewrite: func(pr *ProxyRequest) { rewriteRequestURL(pr.Out, target) - pr.SetXForwarded() + // pr.SetXForwarded() setHeaders(pr.Out) hideHeaders(pr.Out) }, Transport: transport} @@ -348,9 +348,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } outreq.Header.Del("Forwarded") - // outreq.Header.Del("X-Forwarded-For") - // outreq.Header.Del("X-Forwarded-Host") - // outreq.Header.Del("X-Forwarded-Proto") + outreq.Header.Del("X-Forwarded-For") + outreq.Header.Del("X-Forwarded-Host") + outreq.Header.Del("X-Forwarded-Proto") pr := &ProxyRequest{ In: req, diff --git a/src/route/constants.go b/src/route/constants.go index 5be1ebb..46554b5 100644 --- a/src/route/constants.go +++ b/src/route/constants.go @@ -4,5 +4,5 @@ import ( "time" ) -const udpBufferSize = 1500 +const udpBufferSize = 8192 const streamStopListenTimeout = 1 * time.Second diff --git a/src/route/http_route.go b/src/route/http_route.go index 5f796c8..d8f817d 100755 --- a/src/route/http_route.go +++ b/src/route/http_route.go @@ -37,19 +37,23 @@ type ( ) func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { - var trans http.RoundTripper + var trans *http.Transport var regIdleWatcher func() E.NestedError var unregIdleWatcher func() if entry.NoTLSVerify { - trans = transportNoTLS + trans = transportNoTLS.Clone() } else { - trans = transport + trans = transport.Clone() } rp := P.NewReverseProxy(entry.URL, trans, entry) if entry.UseIdleWatcher() { + // allow time for response header up to `WakeTimeout` + if entry.WakeTimeout > trans.ResponseHeaderTimeout { + trans.ResponseHeaderTimeout = entry.WakeTimeout + } regIdleWatcher = func() E.NestedError { watcher, err := idlewatcher.Register(entry) if err.HasError() { @@ -114,6 +118,7 @@ func (r *HTTPRoute) Stop() E.NestedError { if r.unregIdleWatcher != nil { r.unregIdleWatcher() + r.unregIdleWatcher = nil } r.mux = nil @@ -151,13 +156,13 @@ func findMux(host string) (*http.ServeMux, E.NestedError) { } var ( + defaultDialer = net.Dialer{ + Timeout: 60 * time.Second, + KeepAlive: 60 * time.Second, + } transport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 60 * time.Second, - KeepAlive: 60 * time.Second, - }).DialContext, - MaxIdleConns: 1000, + Proxy: http.ProxyFromEnvironment, + DialContext: defaultDialer.DialContext, MaxIdleConnsPerHost: 1000, } transportNoTLS = func() *http.Transport { diff --git a/src/route/stream_route.go b/src/route/stream_route.go index 8c492ab..f160108 100755 --- a/src/route/stream_route.go +++ b/src/route/stream_route.go @@ -2,6 +2,7 @@ package route import ( "context" + "errors" "fmt" "sync" "sync/atomic" @@ -129,7 +130,7 @@ func (r *StreamRoute) grHandleConnections() { case conn := <-r.connCh: go func() { err := r.Handle(conn) - if err != nil { + if err != nil && !errors.Is(err, context.Canceled) { r.l.Error(err) } }() diff --git a/src/route/tcp_route.go b/src/route/tcp_route.go index b2fe35e..b786194 100755 --- a/src/route/tcp_route.go +++ b/src/route/tcp_route.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "sync" - "syscall" "time" U "github.com/yusing/go-proxy/utils" @@ -13,14 +12,16 @@ import ( const tcpDialTimeout = 5 * time.Second -type Pipes []*U.BidirectionalPipe +type ( + Pipes []U.BidirectionalPipe -type TCPRoute struct { - *StreamRoute - listener net.Listener - pipe Pipes - mu sync.Mutex -} + TCPRoute struct { + *StreamRoute + listener net.Listener + pipe Pipes + mu sync.Mutex + } +) func NewTCPRoute(base *StreamRoute) StreamImpl { return &TCPRoute{ @@ -59,10 +60,11 @@ func (route *TCPRoute) Handle(c any) error { } route.mu.Lock() - defer route.mu.Unlock() pipe := U.NewBidirectionalPipe(route.ctx, clientConn, serverConn) route.pipe = append(route.pipe, pipe) + + route.mu.Unlock() return pipe.Start() } @@ -72,16 +74,4 @@ func (route *TCPRoute) CloseListeners() { } route.listener.Close() route.listener = nil - for _, pipe := range route.pipe { - if err := pipe.Stop(); err != nil { - switch err { - // target closing connection - // TODO: handle this by fixing utils/io.go - case net.ErrClosed, syscall.EPIPE: - return - default: - route.l.Error(err) - } - } - } } diff --git a/src/route/udp_route.go b/src/route/udp_route.go index 12a9a1c..eb0a430 100755 --- a/src/route/udp_route.go +++ b/src/route/udp_route.go @@ -4,33 +4,34 @@ import ( "fmt" "io" "net" - "sync" - "github.com/yusing/go-proxy/utils" + U "github.com/yusing/go-proxy/utils" + F "github.com/yusing/go-proxy/utils/functional" ) -type UDPRoute struct { - *StreamRoute +type ( + UDPRoute struct { + *StreamRoute - connMap UDPConnMap - connMapMutex sync.Mutex + connMap UDPConnMap - listeningConn *net.UDPConn - targetAddr *net.UDPAddr -} + listeningConn *net.UDPConn + targetAddr *net.UDPAddr + } + UDPConn struct { + src *net.UDPConn + dst *net.UDPConn + U.BidirectionalPipe + } + UDPConnMap = F.Map[string, *UDPConn] +) -type UDPConn struct { - src *net.UDPConn - dst *net.UDPConn - *utils.BidirectionalPipe -} - -type UDPConnMap map[string]*UDPConn +var NewUDPConnMap = F.NewMapOf[string, *UDPConn] func NewUDPRoute(base *StreamRoute) StreamImpl { return &UDPRoute{ StreamRoute: base, - connMap: make(UDPConnMap), + connMap: NewUDPConnMap(), } } @@ -69,28 +70,24 @@ func (route *UDPRoute) Accept() (any, error) { } key := srcAddr.String() - conn, ok := route.connMap[key] + conn, ok := route.connMap.Load(key) if !ok { - route.connMapMutex.Lock() - if conn, ok = route.connMap[key]; !ok { - srcConn, err := net.DialUDP("udp", nil, srcAddr) - if err != nil { - return nil, err - } - dstConn, err := net.DialUDP("udp", nil, route.targetAddr) - if err != nil { - srcConn.Close() - return nil, err - } - conn = &UDPConn{ - srcConn, - dstConn, - utils.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), - } - route.connMap[key] = conn + srcConn, err := net.DialUDP("udp", nil, srcAddr) + if err != nil { + return nil, err } - route.connMapMutex.Unlock() + dstConn, err := net.DialUDP("udp", nil, route.targetAddr) + if err != nil { + srcConn.Close() + return nil, err + } + conn = &UDPConn{ + srcConn, + dstConn, + U.NewBidirectionalPipe(route.ctx, sourceRWCloser{in, dstConn}, sourceRWCloser{in, srcConn}), + } + route.connMap.Store(key, conn) } _, err = conn.dst.Write(buffer[:nRead]) @@ -106,15 +103,15 @@ func (route *UDPRoute) CloseListeners() { route.listeningConn.Close() route.listeningConn = nil } - for _, conn := range route.connMap { + route.connMap.RangeAll(func(_ string, conn *UDPConn) { if err := conn.src.Close(); err != nil { route.l.Errorf("error closing src conn: %s", err) } if err := conn.dst.Close(); err != nil { route.l.Error("error closing dst conn: %s", err) } - } - route.connMap = make(UDPConnMap) + }) + route.connMap.Clear() } type sourceRWCloser struct { diff --git a/src/utils/io.go b/src/utils/io.go index 10f4fd7..cd0ff84 100644 --- a/src/utils/io.go +++ b/src/utils/io.go @@ -3,9 +3,10 @@ package utils import ( "context" "encoding/json" + "errors" "io" "os" - "sync/atomic" + "syscall" E "github.com/yusing/go-proxy/error" ) @@ -16,15 +17,19 @@ type ( Path string } - ReadCloser struct { - ctx context.Context - r io.ReadCloser - closed atomic.Bool + ContextReader struct { + ctx context.Context + io.Reader + } + + ContextWriter struct { + ctx context.Context + io.Writer } Pipe struct { - r ReadCloser - w io.WriteCloser + r ContextReader + w ContextWriter ctx context.Context cancel context.CancelFunc } @@ -35,48 +40,48 @@ type ( } ) -func (r *ReadCloser) Read(p []byte) (int, error) { +func (r *ContextReader) Read(p []byte) (int, error) { select { case <-r.ctx.Done(): return 0, r.ctx.Err() default: - return r.r.Read(p) + return r.Reader.Read(p) } } -func (r *ReadCloser) Close() error { - if r.closed.Load() { - return nil +func (w *ContextWriter) Write(p []byte) (int, error) { + select { + case <-w.ctx.Done(): + return 0, w.ctx.Err() + default: + return w.Writer.Write(p) } - r.closed.Store(true) - return r.r.Close() } func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe { - ctx, cancel := context.WithCancel(ctx) + _, cancel := context.WithCancel(ctx) return &Pipe{ - r: ReadCloser{ctx: ctx, r: r}, - w: w, + r: ContextReader{ctx: ctx, Reader: r}, + w: ContextWriter{ctx: ctx, Writer: w}, ctx: ctx, cancel: cancel, } } -func (p *Pipe) Start() error { - return Copy(p.ctx, p.w, &p.r) +func (p *Pipe) Start() (err error) { + err = Copy(&p.w, &p.r) + switch { + case + // NOTE: ignoring broken pipe and connection reset by peer + errors.Is(err, syscall.EPIPE), + errors.Is(err, syscall.ECONNRESET): + return nil + } + return err } -func (p *Pipe) Stop() error { - p.cancel() - return E.JoinE("error stopping pipe", p.r.Close(), p.w.Close()).Error() -} - -func (p *Pipe) Write(b []byte) (int, error) { - return p.w.Write(b) -} - -func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) *BidirectionalPipe { - return &BidirectionalPipe{ +func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) BidirectionalPipe { + return BidirectionalPipe{ pSrcDst: NewPipe(ctx, rw1, rw2), pDstSrc: NewPipe(ctx, rw2, rw1), } @@ -89,7 +94,7 @@ func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadClose } } -func (p *BidirectionalPipe) Start() error { +func (p BidirectionalPipe) Start() error { errCh := make(chan error, 2) go func() { errCh <- p.pSrcDst.Start() @@ -97,20 +102,11 @@ func (p *BidirectionalPipe) Start() error { go func() { errCh <- p.pDstSrc.Start() }() - for err := range errCh { - if err != nil { - return err - } - } - return nil + return E.JoinE("bidirectional pipe error", <-errCh, <-errCh).Error() } -func (p *BidirectionalPipe) Stop() error { - return E.JoinE("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()).Error() -} - -func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error { - _, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src}) +func Copy(dst *ContextWriter, src *ContextReader) error { + _, err := io.Copy(dst, src) return err } diff --git a/src/utils/schema.go b/src/utils/schema.go index 79a1fa9..dd66849 100644 --- a/src/utils/schema.go +++ b/src/utils/schema.go @@ -4,13 +4,10 @@ import ( "github.com/santhosh-tekuri/jsonschema" ) -var schemaCompiler = func() *jsonschema.Compiler { - c := jsonschema.NewCompiler() - c.Draft = jsonschema.Draft7 - return c -}() - -var schemaStorage = make(map[string]*jsonschema.Schema) +var ( + schemaCompiler = jsonschema.NewCompiler() + schemaStorage = make(map[string]*jsonschema.Schema) +) func GetSchema(path string) *jsonschema.Schema { if schema, ok := schemaStorage[path]; ok { diff --git a/src/watcher/docker_watcher.go b/src/watcher/docker_watcher.go index 9135c95..c3c686e 100644 --- a/src/watcher/docker_watcher.go +++ b/src/watcher/docker_watcher.go @@ -42,11 +42,19 @@ func DockerrFilterContainerName(name string) filters.KeyValuePair { } func NewDockerWatcher(host string) DockerWatcher { - return DockerWatcher{host: host, FieldLogger: logrus.WithField("module", "docker_watcher")} + return DockerWatcher{ + host: host, + FieldLogger: (logrus. + WithField("module", "docker_watcher"). + WithField("host", host))} } func NewDockerWatcherWithClient(client D.Client) DockerWatcher { - return DockerWatcher{client: client, FieldLogger: logrus.WithField("module", "docker_watcher")} + return DockerWatcher{ + client: client, + FieldLogger: (logrus. + WithField("module", "docker_watcher"). + WithField("host", client.DaemonHost()))} } func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.NestedError) { @@ -56,7 +64,6 @@ func (w DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Neste func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerListOptions) (<-chan Event, <-chan E.NestedError) { eventCh := make(chan Event) errCh := make(chan E.NestedError) - started := make(chan struct{}) eventsCtx, eventsCancel := context.WithCancel(ctx) @@ -75,7 +82,7 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList attempts := 0 for { w.client, err = D.ConnectClient(w.host) - if err != nil { + if err == nil { break } attempts++ @@ -89,8 +96,11 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList } } + w.Debugf("client connected") + cEventCh, cErrCh := w.client.Events(eventsCtx, options) - started <- struct{}{} + + w.Debugf("watcher started") for { select { @@ -130,7 +140,6 @@ func (w DockerWatcher) EventsWithOptions(ctx context.Context, options DockerList } } }() - <-started return eventCh, errCh } diff --git a/src/watcher/events/events.go b/src/watcher/events/events.go index 87abcd2..a892f89 100644 --- a/src/watcher/events/events.go +++ b/src/watcher/events/events.go @@ -14,36 +14,64 @@ type ( ActorAttributes map[string]string Action Action } - Action string + Action uint16 EventType string ) const ( - ActionFileModified Action = "modified" - ActionFileCreated Action = "created" - ActionFileDeleted Action = "deleted" + ActionFileModified Action = (1 << iota) + ActionFileCreated + ActionFileDeleted - ActionDockerStartUnpause Action = "start" - ActionDockerStopPause Action = "stop" + ActionContainerCreate + ActionContainerStart + ActionContainerUnpause + ActionContainerKill + ActionContainerStop + ActionContainerPause + ActionContainerDie + + actionContainerWakeMask = ActionContainerCreate | ActionContainerStart | ActionContainerUnpause + actionContainerSleepMask = ActionContainerKill | ActionContainerStop | ActionContainerPause | ActionContainerDie +) + +const ( EventTypeDocker EventType = "docker" EventTypeFile EventType = "file" ) var DockerEventMap = map[dockerEvents.Action]Action{ - dockerEvents.ActionCreate: ActionDockerStartUnpause, - dockerEvents.ActionStart: ActionDockerStartUnpause, - dockerEvents.ActionPause: ActionDockerStartUnpause, - dockerEvents.ActionDie: ActionDockerStopPause, - dockerEvents.ActionStop: ActionDockerStopPause, - dockerEvents.ActionUnPause: ActionDockerStopPause, - dockerEvents.ActionKill: ActionDockerStopPause, + dockerEvents.ActionCreate: ActionContainerCreate, + dockerEvents.ActionStart: ActionContainerStart, + dockerEvents.ActionUnPause: ActionContainerUnpause, + + dockerEvents.ActionKill: ActionContainerKill, + dockerEvents.ActionStop: ActionContainerStop, + dockerEvents.ActionPause: ActionContainerPause, + dockerEvents.ActionDie: ActionContainerDie, } +var dockerActionNameMap = func() (m map[Action]string) { + m = make(map[Action]string, len(DockerEventMap)) + for k, v := range DockerEventMap { + m[v] = string(k) + } + return +}() + func (e Event) String() string { return fmt.Sprintf("%s %s", e.ActorName, e.Action) } -func (a Action) IsDelete() bool { - return a == ActionFileDeleted +func (a Action) String() string { + return dockerActionNameMap[a] +} + +func (a Action) IsContainerWake() bool { + return a&actionContainerWakeMask != 0 +} + +func (a Action) IsContainerSleep() bool { + return a&actionContainerSleepMask != 0 }