diff --git a/internal/utils/ref_count.go b/internal/utils/ref_count.go index ac6eb01..782783a 100644 --- a/internal/utils/ref_count.go +++ b/internal/utils/ref_count.go @@ -24,11 +24,31 @@ func (rc *RefCount) Zero() <-chan struct{} { } func (rc *RefCount) Add() { - atomic.AddUint32(&rc.refCount, 1) + // We add before checking to ensure proper ordering + newV := atomic.AddUint32(&rc.refCount, 1) + if newV == 1 { + // If it was 0 before we added, that means we're incrementing after a close + // This is a programming error + panic("RefCount.Add() called after count reached zero") + } } func (rc *RefCount) Sub() { - if atomic.AddUint32(&rc.refCount, ^uint32(0)) == 0 { - close(rc.zeroCh) + // First read the current value + for { + current := atomic.LoadUint32(&rc.refCount) + if current == 0 { + // Already at zero, channel should be closed + return + } + + // Try to decrement, but only if the value hasn't changed + if atomic.CompareAndSwapUint32(&rc.refCount, current, current-1) { + if current == 1 { // Was this the last reference? + close(rc.zeroCh) + } + return + } + // If CAS failed, someone else modified the count, try again } } diff --git a/internal/utils/ref_count_test.go b/internal/utils/ref_count_test.go index c147638..d6e64cd 100644 --- a/internal/utils/ref_count_test.go +++ b/internal/utils/ref_count_test.go @@ -4,6 +4,8 @@ import ( "sync" "testing" "time" + + . "github.com/yusing/go-proxy/internal/utils/testing" ) func TestRefCounterAddSub(t *testing.T) { @@ -12,18 +14,16 @@ func TestRefCounterAddSub(t *testing.T) { var wg sync.WaitGroup wg.Add(2) - go func() { - defer wg.Done() - rc.Add() - }() - - go func() { - defer wg.Done() - rc.Sub() - rc.Sub() - }() + rc.Add() + for range 2 { + go func() { + defer wg.Done() + rc.Sub() + }() + } wg.Wait() + ExpectEqual(t, int(rc.refCount), 0) select { case <-rc.Zero(): @@ -39,7 +39,7 @@ func TestRefCounterMultipleAddSub(t *testing.T) { var wg sync.WaitGroup numAdds := 5 numSubs := 5 - wg.Add(numAdds + numSubs) + wg.Add(numAdds) for range numAdds { go func() { @@ -47,17 +47,20 @@ func TestRefCounterMultipleAddSub(t *testing.T) { rc.Add() }() } + wg.Wait() + ExpectEqual(t, int(rc.refCount), numAdds+1) + wg.Add(numSubs) for range numSubs { go func() { defer wg.Done() rc.Sub() - rc.Sub() }() } - wg.Wait() + ExpectEqual(t, int(rc.refCount), numAdds+1-numSubs) + rc.Sub() select { case <-rc.Zero(): // Expected behavior