From af14966b09ea2bbd9fa3dbaa71c1c581b6e814c1 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 2 Jan 2025 09:59:31 +0800 Subject: [PATCH] rewrite and fix reference counter --- internal/utils/ref_count.go | 41 +++++++++-------- internal/utils/ref_count_test.go | 77 ++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 21 deletions(-) create mode 100644 internal/utils/ref_count_test.go diff --git a/internal/utils/ref_count.go b/internal/utils/ref_count.go index 40a785a..a61a68e 100644 --- a/internal/utils/ref_count.go +++ b/internal/utils/ref_count.go @@ -1,42 +1,41 @@ package utils +import ( + "sync" + "sync/atomic" +) + type RefCount struct { _ NoCopy - refCh chan bool - notifyZero chan struct{} + mu sync.Mutex + cond *sync.Cond + refCount uint32 + zeroCh chan struct{} } func NewRefCounter() *RefCount { rc := &RefCount{ - refCh: make(chan bool, 1), - notifyZero: make(chan struct{}), + refCount: 1, + zeroCh: make(chan struct{}), } - go func() { - refCount := uint32(1) - for isAdd := range rc.refCh { - if isAdd { - refCount++ - } else { - refCount-- - } - if refCount <= 0 { - close(rc.notifyZero) - return - } - } - }() + rc.cond = sync.NewCond(&rc.mu) return rc } func (rc *RefCount) Zero() <-chan struct{} { - return rc.notifyZero + return rc.zeroCh } func (rc *RefCount) Add() { - rc.refCh <- true + atomic.AddUint32(&rc.refCount, 1) } func (rc *RefCount) Sub() { - rc.refCh <- false + if atomic.AddUint32(&rc.refCount, ^uint32(0)) == 0 { + rc.mu.Lock() + close(rc.zeroCh) + rc.cond.Broadcast() + rc.mu.Unlock() + } } diff --git a/internal/utils/ref_count_test.go b/internal/utils/ref_count_test.go new file mode 100644 index 0000000..9d0a2aa --- /dev/null +++ b/internal/utils/ref_count_test.go @@ -0,0 +1,77 @@ +package utils + +import ( + "sync" + "testing" + "time" +) + +func TestRefCounter_AddSub(t *testing.T) { + rc := NewRefCounter() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + rc.Add() + }() + + go func() { + defer wg.Done() + rc.Sub() + }() + + wg.Wait() + + select { + case <-rc.Zero(): + // Expected behavior + case <-time.After(1 * time.Second): + t.Fatal("Expected Zero channel to close, but it didn't") + } +} + +func TestRefCounter_MultipleAddSub(t *testing.T) { + rc := NewRefCounter() + + var wg sync.WaitGroup + numAdds := 5 + numSubs := 5 + wg.Add(numAdds + numSubs) + + for range numAdds { + go func() { + defer wg.Done() + rc.Add() + }() + } + + for range numSubs { + go func() { + defer wg.Done() + rc.Sub() + }() + } + + wg.Wait() + + select { + case <-rc.Zero(): + // Expected behavior + case <-time.After(1 * time.Second): + t.Fatal("Expected Zero channel to close, but it didn't") + } +} + +func TestRefCounter_ZeroInitially(t *testing.T) { + rc := NewRefCounter() + rc.Sub() // Bring count to zero + + select { + case <-rc.Zero(): + // Expected behavior + case <-time.After(1 * time.Second): + t.Fatal("Expected Zero channel to close, but it didn't") + } +}