graceful shutdown and ref count related

This commit is contained in:
yusing 2024-10-14 10:31:27 +08:00
parent f38b3abdbc
commit 53fa28ae77
4 changed files with 178 additions and 41 deletions

View file

@ -13,8 +13,8 @@ import (
var ( var (
globalCtx, globalCtxCancel = context.WithCancel(context.Background()) globalCtx, globalCtxCancel = context.WithCancel(context.Background())
globalCtxWg sync.WaitGroup taskWg sync.WaitGroup
globalCtxTraceMap = xsync.NewMapOf[*task, struct{}]() tasksMap = xsync.NewMapOf[*task, struct{}]()
) )
type ( type (
@ -38,10 +38,14 @@ func (t *task) Name() string {
return t.name return t.name
} }
// Context returns the context associated with the task. This context is
// canceled when the task is finished.
func (t *task) Context() context.Context { func (t *task) Context() context.Context {
return t.ctx return t.ctx
} }
// Finished marks the task as finished and notifies the global wait group.
// Finished is thread-safe and idempotent.
func (t *task) Finished() { func (t *task) Finished() {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
@ -50,12 +54,22 @@ func (t *task) Finished() {
return return
} }
t.finished = true t.finished = true
if _, ok := globalCtxTraceMap.Load(t); ok { if _, ok := tasksMap.Load(t); ok {
globalCtxWg.Done() taskWg.Done()
globalCtxTraceMap.Delete(t) tasksMap.Delete(t)
} }
logrus.Debugf("task %q finished", t.Name())
} }
// Subtask returns a new subtask with the given name, derived from the receiver's context.
//
// The returned subtask is associated with the receiver's context and will be
// automatically registered and deregistered from the global task wait group.
//
// If the receiver's context is already canceled, the returned subtask will be
// canceled immediately.
//
// The returned subtask is safe for concurrent use.
func (t *task) Subtask(format string, args ...interface{}) Task { func (t *task) Subtask(format string, args ...interface{}) Task {
if len(args) > 0 { if len(args) > 0 {
format = fmt.Sprintf(format, args...) format = fmt.Sprintf(format, args...)
@ -67,6 +81,14 @@ func (t *task) Subtask(format string, args ...interface{}) Task {
return sub return sub
} }
// SubtaskWithCancel returns a new subtask with the given name, derived from the receiver's context,
// and a cancel function. The returned subtask is associated with the receiver's context and will be
// automatically registered and deregistered from the global task wait group.
//
// If the receiver's context is already canceled, the returned subtask will be canceled immediately.
//
// The returned cancel function is safe for concurrent use, and can be used to cancel the returned
// subtask at any time.
func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
if len(args) > 0 { if len(args) > 0 {
format = fmt.Sprintf(format, args...) format = fmt.Sprintf(format, args...)
@ -79,7 +101,7 @@ func (t *task) SubtaskWithCancel(format string, args ...interface{}) (Task, cont
return sub, cancel return sub, cancel
} }
func (t *task) Tree(prefix ...string) string { func (t *task) tree(prefix ...string) string {
var sb strings.Builder var sb strings.Builder
var pre string var pre string
if len(prefix) > 0 { if len(prefix) > 0 {
@ -91,7 +113,7 @@ func (t *task) Tree(prefix ...string) string {
if sub.finished { if sub.finished {
continue continue
} }
sb.WriteString(sub.Tree(pre + " ")) sb.WriteString(sub.tree(pre + " "))
} }
return sb.String() return sb.String()
} }
@ -101,11 +123,22 @@ func newSubTask(ctx context.Context, name string) *task {
ctx: ctx, ctx: ctx,
name: name, name: name,
} }
globalCtxTraceMap.Store(t, struct{}{}) tasksMap.Store(t, struct{}{})
globalCtxWg.Add(1) taskWg.Add(1)
return t return t
} }
// NewTask returns a new Task with the given name, derived from the global
// context.
//
// The returned Task is associated with the global context and will be
// automatically registered and deregistered from the global context's wait
// group.
//
// If the global context is already canceled, the returned Task will be
// canceled immediately.
//
// The returned Task is not safe for concurrent use.
func NewTask(format string, args ...interface{}) Task { func NewTask(format string, args ...interface{}) Task {
if len(args) > 0 { if len(args) > 0 {
format = fmt.Sprintf(format, args...) format = fmt.Sprintf(format, args...)
@ -113,6 +146,18 @@ func NewTask(format string, args ...interface{}) Task {
return newSubTask(globalCtx, format) return newSubTask(globalCtx, format)
} }
// NewTaskWithCancel returns a new Task with the given name, derived from the
// global context, and a cancel function. The returned Task is associated with
// the global context and will be automatically registered and deregistered
// from the global task wait group.
//
// If the global context is already canceled, the returned Task will be
// canceled immediately.
//
// The returned Task is safe for concurrent use.
//
// The returned cancel function is safe for concurrent use, and can be used
// to cancel the returned Task at any time.
func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) { func NewTaskWithCancel(format string, args ...interface{}) (Task, context.CancelFunc) {
subCtx, cancel := context.WithCancel(globalCtx) subCtx, cancel := context.WithCancel(globalCtx)
if len(args) > 0 { if len(args) > 0 {
@ -121,6 +166,17 @@ func NewTaskWithCancel(format string, args ...interface{}) (Task, context.Cancel
return newSubTask(subCtx, format), cancel return newSubTask(subCtx, format), cancel
} }
// GlobalTask returns a new Task with the given name, associated with the
// global context.
//
// Unlike NewTask, GlobalTask does not automatically register or deregister
// the Task with the global task wait group. The returned Task is not
// started, but the name is formatted immediately.
//
// This is best used for main task that do not need to wait and
// will create a bunch of subtasks.
//
// The returned Task is safe for concurrent use.
func GlobalTask(format string, args ...interface{}) Task { func GlobalTask(format string, args ...interface{}) Task {
if len(args) > 0 { if len(args) > 0 {
format = fmt.Sprintf(format, args...) format = fmt.Sprintf(format, args...)
@ -131,15 +187,24 @@ func GlobalTask(format string, args ...interface{}) Task {
} }
} }
// CancelGlobalContext cancels the global context, which will cause all tasks
// created by GlobalTask or NewTask to be canceled. This should be called
// before exiting the program to ensure that all tasks are properly cleaned
// up.
func CancelGlobalContext() { func CancelGlobalContext() {
globalCtxCancel() globalCtxCancel()
} }
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
//
// If the timeout is exceeded, it prints a list of all tasks that were
// still running when the timeout was reached, and their current tree
// of subtasks.
func GlobalContextWait(timeout time.Duration) { func GlobalContextWait(timeout time.Duration) {
done := make(chan struct{}) done := make(chan struct{})
after := time.After(timeout) after := time.After(timeout)
go func() { go func() {
globalCtxWg.Wait() taskWg.Wait()
close(done) close(done)
}() }()
for { for {
@ -148,8 +213,8 @@ func GlobalContextWait(timeout time.Duration) {
return return
case <-after: case <-after:
logrus.Println("Timeout waiting for these tasks to finish:") logrus.Println("Timeout waiting for these tasks to finish:")
globalCtxTraceMap.Range(func(t *task, _ struct{}) bool { tasksMap.Range(func(t *task, _ struct{}) bool {
logrus.Println(t.Tree()) logrus.Println(t.tree())
return true return true
}) })
return return

View file

@ -3,24 +3,27 @@ package docker
import ( import (
"net/http" "net/http"
"sync" "sync"
"sync/atomic"
"github.com/docker/cli/cli/connhelper" "github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
U "github.com/yusing/go-proxy/internal/utils"
F "github.com/yusing/go-proxy/internal/utils/functional" F "github.com/yusing/go-proxy/internal/utils/functional"
) )
type Client struct { type (
*client.Client Client = *SharedClient
SharedClient struct {
*client.Client
key string key string
refCount *atomic.Int32 refCount *U.RefCount
l logrus.FieldLogger l logrus.FieldLogger
} }
)
var ( var (
clientMap F.Map[string, Client] = F.NewMapOf[string, Client]() clientMap F.Map[string, Client] = F.NewMapOf[string, Client]()
@ -32,26 +35,34 @@ var (
} }
) )
func (c Client) Connected() bool { func init() {
return c.Client != nil go func() {
task := common.NewTask("close all docker client")
defer task.Finished()
for {
select {
case <-task.Context().Done():
clientMap.RangeAllParallel(func(_ string, c Client) {
c.Client.Close()
})
clientMap.Clear()
return
}
}
}()
}
func (c *SharedClient) Connected() bool {
return c != nil && c.Client != nil
} }
// if the client is still referenced, this is no-op. // if the client is still referenced, this is no-op.
func (c *Client) Close() error { func (c *SharedClient) Close() error {
if c.refCount.Add(-1) > 0 { if !c.Connected() {
return nil return nil
} }
clientMap.Delete(c.key) c.refCount.Sub()
client := c.Client
c.Client = nil
c.l.Debugf("client closed")
if client != nil {
return client.Close()
}
return nil return nil
} }
@ -71,7 +82,7 @@ func ConnectClient(host string) (Client, E.NestedError) {
// check if client exists // check if client exists
if client, ok := clientMap.Load(host); ok { if client, ok := clientMap.Load(host); ok {
client.refCount.Add(1) client.refCount.Add()
return client, nil return client, nil
} }
@ -80,13 +91,13 @@ func ConnectClient(host string) (Client, E.NestedError) {
switch host { switch host {
case "": case "":
return Client{}, E.Invalid("docker host", "empty") return nil, E.Invalid("docker host", "empty")
case common.DockerHostFromEnv: case common.DockerHostFromEnv:
opt = clientOptEnvHost opt = clientOptEnvHost
default: default:
helper, err := E.Check(connhelper.GetConnectionHelper(host)) helper, err := E.Check(connhelper.GetConnectionHelper(host))
if err.HasError() { if err.HasError() {
return Client{}, E.UnexpectedError(err.Error()) return nil, E.UnexpectedError(err.Error())
} }
if helper != nil { if helper != nil {
httpClient := &http.Client{ httpClient := &http.Client{
@ -111,19 +122,29 @@ func ConnectClient(host string) (Client, E.NestedError) {
client, err := E.Check(client.NewClientWithOpts(opt...)) client, err := E.Check(client.NewClientWithOpts(opt...))
if err.HasError() { if err.HasError() {
return Client{}, err return nil, err
} }
c := Client{ c := &SharedClient{
Client: client, Client: client,
key: host, key: host,
refCount: &atomic.Int32{}, refCount: U.NewRefCounter(),
l: logger.WithField("docker_client", client.DaemonHost()), l: logger.WithField("docker_client", client.DaemonHost()),
} }
c.refCount.Add(1)
c.l.Debugf("client connected") c.l.Debugf("client connected")
clientMap.Store(host, c) clientMap.Store(host, c)
go func() {
<-c.refCount.Zero()
clientMap.Delete(c.key)
if c.Client != nil {
c.Client.Close()
c.Client = nil
c.l.Debugf("client closed")
}
}()
return c, nil return c, nil
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/autocert" "github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -20,6 +21,7 @@ type Server struct {
httpStarted bool httpStarted bool
httpsStarted bool httpsStarted bool
startTime time.Time startTime time.Time
task common.Task
} }
type Options struct { type Options struct {
@ -82,6 +84,7 @@ func NewServer(opt Options) (s *Server) {
CertProvider: opt.CertProvider, CertProvider: opt.CertProvider,
http: httpSer, http: httpSer,
https: httpsSer, https: httpsSer,
task: common.GlobalTask("Server " + opt.Name),
} }
} }
@ -111,9 +114,15 @@ func (s *Server) Start() {
s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath())) s.handleErr("https", s.https.ListenAndServeTLS(s.CertProvider.GetCertPath(), s.CertProvider.GetKeyPath()))
}() }()
} }
go func() {
<-s.task.Context().Done()
s.stop()
s.task.Finished()
}()
} }
func (s *Server) Stop() { func (s *Server) stop() {
if s.http == nil && s.https == nil { if s.http == nil && s.https == nil {
return return
} }

View file

@ -0,0 +1,42 @@
package utils
type RefCount struct {
_ NoCopy
refCh chan bool
notifyZero chan struct{}
}
func NewRefCounter() *RefCount {
rc := &RefCount{
refCh: make(chan bool, 1),
notifyZero: make(chan struct{}),
}
go func() {
refCount := uint32(1)
for isAdd := range rc.refCh {
if isAdd {
refCount++
} else {
refCount--
}
if refCount <= 0 {
close(rc.notifyZero)
return
}
}
}()
return rc
}
func (rc *RefCount) Zero() <-chan struct{} {
return rc.notifyZero
}
func (rc *RefCount) Add() {
rc.refCh <- true
}
func (rc *RefCount) Sub() {
rc.refCh <- false
}