From 53fa28ae771727c5f0aefdce17f806839eec11ad Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 14 Oct 2024 10:31:27 +0800 Subject: [PATCH] graceful shutdown and ref count related --- internal/common/task.go | 89 ++++++++++++++++++++++++++++++++----- internal/docker/client.go | 77 ++++++++++++++++++++------------ internal/server/server.go | 11 ++++- internal/utils/ref_count.go | 42 +++++++++++++++++ 4 files changed, 178 insertions(+), 41 deletions(-) create mode 100644 internal/utils/ref_count.go diff --git a/internal/common/task.go b/internal/common/task.go index 1f4120d..6e0e61b 100644 --- a/internal/common/task.go +++ b/internal/common/task.go @@ -13,8 +13,8 @@ import ( var ( globalCtx, globalCtxCancel = context.WithCancel(context.Background()) - globalCtxWg sync.WaitGroup - globalCtxTraceMap = xsync.NewMapOf[*task, struct{}]() + taskWg sync.WaitGroup + tasksMap = xsync.NewMapOf[*task, struct{}]() ) type ( @@ -38,10 +38,14 @@ func (t *task) Name() string { return t.name } +// Context returns the context stored on the task. NOTE(review): Finished, as +// shown in this patch, never cancels it; confirm what is expected to cancel it. func (t *task) Context() context.Context { return t.ctx } +// Finished marks the task as finished and, if it is registered in tasksMap, +// calls taskWg.Done and unregisters it. It runs under t.mu and is idempotent. func (t *task) Finished() { t.mu.Lock() defer t.mu.Unlock() @@ -50,12 +54,22 @@ func (t *task) Finished() { return } t.finished = true - if _, ok := globalCtxTraceMap.Load(t); ok { - globalCtxWg.Done() - globalCtxTraceMap.Delete(t) + if _, ok := tasksMap.Load(t); ok { + taskWg.Done() + tasksMap.Delete(t) } + logrus.Debugf("task %q finished", t.Name()) } +// Subtask returns a new subtask with the given name, derived from the receiver's context. +// +// The returned subtask is associated with the receiver's context and will be +// automatically registered and deregistered from the global task wait group. 
+// +// If the receiver's context is already canceled, the returned subtask will be +// canceled immediately. +// +// The returned subtask is safe for concurrent use. func (t *task) Subtask(format string, args ...interface{}) Task { if len(args) > 0 { format = fmt.Sprintf(format, args...) @@ -67,6 +81,14 @@ func (t *task) Subtask(format string, args ...interface{}) Task { return sub } +// SubtaskWithCancel returns a new subtask with the given name, derived from the receiver's context, +// and a cancel function. The returned subtask is associated with the receiver's context and will be +// automatically registered and deregistered from the global task wait group. +// +// If the receiver's context is already canceled, the returned subtask will be canceled immediately. +// +// The returned cancel function is safe for concurrent use, and can be used to cancel the returned +// subtask at any time. func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { if len(args) > 0 { format = fmt.Sprintf(format, args...) @@ -79,7 +101,7 @@ func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, cont return sub, cancel } -func (t *task) Tree(prefix ...string) string { +func (t *task) tree(prefix ...string) string { var sb strings.Builder var pre string if len(prefix) > 0 { @@ -91,7 +113,7 @@ func (t *task) Tree(prefix ...string) string { if sub.finished { continue } - sb.WriteString(sub.Tree(pre + " ")) + sb.WriteString(sub.tree(pre + " ")) } return sb.String() } @@ -101,11 +123,22 @@ func newSubTask(ctx context.Context, name string) *task { ctx: ctx, name: name, } - globalCtxTraceMap.Store(t, struct{}{}) - globalCtxWg.Add(1) + tasksMap.Store(t, struct{}{}) + taskWg.Add(1) return t } +// NewTask returns a new Task with the given name, derived from the global +// context. 
+// +// The returned Task is associated with the global context and is registered +// with the global task wait group on creation; it stays registered until +// Finished is called. +// +// If the global context is already canceled, the returned Task will be +// canceled immediately. +// +// NOTE(review): built by newSubTask exactly like NewTaskWithCancel, so both share the same concurrency safety. func NewTask(format string, args ...interface{}) Task { if len(args) > 0 { format = fmt.Sprintf(format, args...) @@ -113,6 +146,18 @@ func NewTask(format string, args ...interface{}) Task { return newSubTask(globalCtx, format) } +// NewTaskWithCancel returns a new Task with the given name and a cancel +// function. The Task's context is a cancelable child of the global context. +// The Task is registered with the global task wait group on creation and +// stays registered until Finished is called. +// +// If the global context is already canceled, the returned Task will be +// canceled immediately. +// +// The returned Task is safe for concurrent use. +// +// The returned cancel function is safe for concurrent use, and can be used +// to cancel the returned Task at any time. func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { subCtx, cancel := context.WithCancel(globalCtx) if len(args) > 0 { @@ -121,6 +166,17 @@ func NewTaskWithCancel(format string, args ...interface{}) (Task, context.Cancel return newSubTask(subCtx, format), cancel } +// GlobalTask returns a new Task with the given name, associated with the +// global context. +// +// Unlike NewTask, GlobalTask does not automatically register or deregister +// the Task with the global task wait group. The returned Task is not +// started, but the name is formatted immediately. +// +// This is best used for a main task that does not need to be waited on and +// will create a bunch of subtasks. +// +// The returned Task is safe for concurrent use. func GlobalTask(format string, args ...interface{}) Task { if len(args) > 0 { format = fmt.Sprintf(format, args...) 
@@ -131,15 +187,24 @@ func GlobalTask(format string, args ...interface{}) Task { } } +// CancelGlobalContext cancels the global context, which cancels every task +// context derived from it (NewTask, NewTaskWithCancel, GlobalTask). Call it +// before exiting the program, then use GlobalContextWait to give tasks time +// to clean up. func CancelGlobalContext() { globalCtxCancel() } +// GlobalContextWait waits for every registered task to finish, up to timeout. +// +// If the timeout is exceeded, it logs (via logrus) every task still present +// in tasksMap when the timeout was reached, along with its current tree +// of subtasks. func GlobalContextWait(timeout time.Duration) { done := make(chan struct{}) after := time.After(timeout) go func() { - globalCtxWg.Wait() + taskWg.Wait() close(done) }() for { @@ -148,8 +213,8 @@ func GlobalContextWait(timeout time.Duration) { return case <-after: logrus.Println("Timeout waiting for these tasks to finish:") - globalCtxTraceMap.Range(func(t *task, _ struct{}) bool { - logrus.Println(t.Tree()) + tasksMap.Range(func(t *task, _ struct{}) bool { + logrus.Println(t.tree()) return true }) return diff --git a/internal/docker/client.go b/internal/docker/client.go index 5f28fd1..a07a27e 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -3,24 +3,27 @@ package docker import ( "net/http" "sync" - "sync/atomic" "github.com/docker/cli/cli/connhelper" "github.com/docker/docker/client" "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) -type Client struct { - *client.Client +type ( + Client = *SharedClient + SharedClient struct { + *client.Client - key string - refCount *atomic.Int32 + key string + refCount *U.RefCount - l logrus.FieldLogger -} + l logrus.FieldLogger + } +) var ( clientMap F.Map[string, Client] = F.NewMapOf[string, Client]() @@ -32,26 +35,34 @@ var 
( } ) -func (c Client) Connected() bool { - return c.Client != nil +func init() { + go func() { + task := common.NewTask("close all docker client") + defer task.Finished() + for { + select { + case <-task.Context().Done(): + clientMap.RangeAllParallel(func(_ string, c Client) { + c.Client.Close() + }) + clientMap.Clear() + return + } + } + }() +} + +func (c *SharedClient) Connected() bool { + return c != nil && c.Client != nil } // if the client is still referenced, this is no-op. -func (c *Client) Close() error { - if c.refCount.Add(-1) > 0 { +func (c *SharedClient) Close() error { + if !c.Connected() { return nil } - clientMap.Delete(c.key) - - client := c.Client - c.Client = nil - - c.l.Debugf("client closed") - - if client != nil { - return client.Close() - } + c.refCount.Sub() return nil } @@ -71,7 +82,7 @@ func ConnectClient(host string) (Client, E.NestedError) { // check if client exists if client, ok := clientMap.Load(host); ok { - client.refCount.Add(1) + client.refCount.Add() return client, nil } @@ -80,13 +91,13 @@ func ConnectClient(host string) (Client, E.NestedError) { switch host { case "": - return Client{}, E.Invalid("docker host", "empty") + return nil, E.Invalid("docker host", "empty") case common.DockerHostFromEnv: opt = clientOptEnvHost default: helper, err := E.Check(connhelper.GetConnectionHelper(host)) if err.HasError() { - return Client{}, E.UnexpectedError(err.Error()) + return nil, E.UnexpectedError(err.Error()) } if helper != nil { httpClient := &http.Client{ @@ -111,19 +122,29 @@ func ConnectClient(host string) (Client, E.NestedError) { client, err := E.Check(client.NewClientWithOpts(opt...)) if err.HasError() { - return Client{}, err + return nil, err } - c := Client{ + c := &SharedClient{ Client: client, key: host, - refCount: &atomic.Int32{}, + refCount: U.NewRefCounter(), l: logger.WithField("docker_client", client.DaemonHost()), } - c.refCount.Add(1) c.l.Debugf("client connected") clientMap.Store(host, c) + + go func() { + 
<-c.refCount.Zero() + clientMap.Delete(c.key) + + if c.Client != nil { + c.Client.Close() + c.Client = nil + c.l.Debugf("client closed") + } + }() return c, nil } diff --git a/internal/server/server.go b/internal/server/server.go index fb94634..44e349b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal/autocert" + "github.com/yusing/go-proxy/internal/common" "golang.org/x/net/context" ) @@ -20,6 +21,7 @@ type Server struct { httpStarted bool httpsStarted bool startTime time.Time + task common.Task } type Options struct { @@ -82,6 +84,7 @@ func NewServer(opt Options) (s *Server) { CertProvider: opt.CertProvider, http: httpSer, https: httpsSer, + task: common.GlobalTask("Server " + opt.Name), } } @@ -111,9 +114,15 @@ func (s *Server) Start() { s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath())) }() } + + go func() { + <-s.task.Context().Done() + s.stop() + s.task.Finished() + }() } -func (s *Server) Stop() { +func (s *Server) stop() { if s.http == nil && s.https == nil { return } diff --git a/internal/utils/ref_count.go b/internal/utils/ref_count.go new file mode 100644 index 0000000..40a785a --- /dev/null +++ b/internal/utils/ref_count.go @@ -0,0 +1,42 @@ +package utils + +type RefCount struct { + _ NoCopy + + refCh chan bool + notifyZero chan struct{} +} + +func NewRefCounter() *RefCount { + rc := &RefCount{ + refCh: make(chan bool, 1), + notifyZero: make(chan struct{}), + } + go func() { + refCount := uint32(1) + for isAdd := range rc.refCh { + if isAdd { + refCount++ + } else { + refCount-- + } + if refCount <= 0 { + close(rc.notifyZero) + return + } + } + }() + return rc +} + +func (rc *RefCount) Zero() <-chan struct{} { + return rc.notifyZero +} + +func (rc *RefCount) Add() { + rc.refCh <- true +} + +func (rc *RefCount) Sub() { + rc.refCh <- false +}