diff --git a/group_options.go b/group_options.go index 1344335..94d9679 100644 --- a/group_options.go +++ b/group_options.go @@ -7,6 +7,7 @@ type options struct { preLock bool termOnError bool discardIfFull bool + tresholdSize int } // GroupOption functional option type @@ -34,3 +35,18 @@ func Discard(o *options) { o.discardIfFull = true o.preLock = true // discard implies preemptive } + +// DiscardAfterTreshold works similarly to Discard, but buffers tasks if all goroutines are busy +// until the treshold size of 'active' tasks (i.e. executing and scheduled for execution) is achieved +// If this value is lower than size, it will be ignored and common Discard mode will is used +func DiscardAfterTreshold(tresholdSize int) GroupOption { + return func(o *options) { + o.discardIfFull = true + o.preLock = true + + if tresholdSize < 1 { + tresholdSize = 0 + } + o.tresholdSize = tresholdSize + } +} diff --git a/semaphore_test.go b/semaphore_test.go index 3757d2b..c01d00c 100644 --- a/semaphore_test.go +++ b/semaphore_test.go @@ -41,16 +41,16 @@ func TestSemaphore(t *testing.T) { // if number of locks are less than capacity, all should be acquired if tt.lockTimes <= tt.capacity { - assert.Equal(t, int32(tt.lockTimes), atomic.LoadInt32(&locks)) + assert.Equal(t, tt.lockTimes, int(atomic.LoadInt32(&locks))) wg.Wait() return } // if number of locks exceed capacity, it should hang after reaching the capacity - assert.Equal(t, int32(tt.capacity), atomic.LoadInt32(&locks)) + assert.Equal(t, tt.capacity, int(atomic.LoadInt32(&locks))) sema.Unlock() time.Sleep(10 * time.Millisecond) // after unlock, it should be able to acquire another lock - assert.Equal(t, int32(tt.capacity+1), atomic.LoadInt32(&locks)) + assert.Equal(t, tt.capacity+1, int(atomic.LoadInt32(&locks))) wg.Wait() }) } @@ -81,7 +81,7 @@ func TestSemaphore_TryLock(t *testing.T) { } // Check the acquired locks, it should not exceed capacity. - assert.Equal(t, int32(tt.expectedLocks), atomic.LoadInt32(&locks)) + assert.Equal(t, tt.expectedLocks, int(atomic.LoadInt32(&locks))) }) } } diff --git a/sizedgroup.go b/sizedgroup.go index b73f77a..4dc9e50 100644 --- a/sizedgroup.go +++ b/sizedgroup.go @@ -10,62 +10,110 @@ import ( // SizedGroup interface enforces constructor usage and doesn't allow direct creation of sizedGroup type SizedGroup struct { options - wg sync.WaitGroup - sema Locker + wg sync.WaitGroup + workers chan struct{} + scheduledJobs chan struct{} + jobQueue chan func(ctx context.Context) + workersMutex sync.Mutex } // NewSizedGroup makes wait group with limited size alive goroutines func NewSizedGroup(size int, opts ...GroupOption) *SizedGroup { - res := SizedGroup{sema: NewSemaphore(size)} + if size < 1 { + size = 1 + } + res := SizedGroup{workers: make(chan struct{}, size)} res.options.ctx = context.Background() for _, opt := range opts { opt(&res.options) } + + // queue size either equal to number of workers or larger, otherwise does not make sense + queueSize := size + if res.tresholdSize > size { + queueSize = res.tresholdSize + } + + res.jobQueue = make(chan func(ctx context.Context), queueSize) + res.scheduledJobs = make(chan struct{}, queueSize) return &res } // Go calls the given function in a new goroutine. // Every call will be unblocked, but some goroutines may wait if semaphore locked. func (g *SizedGroup) Go(fn func(ctx context.Context)) { - canceled := func() bool { - select { - case <-g.ctx.Done(): - return true - default: - return false - } + if g.canceled() { + return } - if canceled() { + g.wg.Add(1) + if !g.preLock { + go func() { + defer g.wg.Done() + if g.canceled() { + return + } + g.scheduledJobs <- struct{}{} + fn(g.ctx) + <-g.scheduledJobs + }() return } - if g.preLock { - lockOk := g.sema.TryLock() - if !lockOk && g.discardIfFull { - // lock failed and discardIfFull is set, discard this goroutine + toRun := func(job func(ctx context.Context)) { + defer g.wg.Done() + if g.canceled() { return } - if !lockOk && !g.discardIfFull { - g.sema.Lock() // make sure we have block until lock is acquired - } + job(g.ctx) + <-g.scheduledJobs } - g.wg.Add(1) - go func() { - defer g.wg.Done() - - if canceled() { - return + startWorkerIfNeeded := func() { + g.workersMutex.Lock() + select { + case g.workers <- struct{}{}: + g.workersMutex.Unlock() + go func() { + for { + select { + case job := <-g.jobQueue: + toRun(job) + default: + g.workersMutex.Lock() + select { + case job := <-g.jobQueue: + g.workersMutex.Unlock() + toRun(job) + continue + default: + <-g.workers + g.workersMutex.Unlock() + } + return + } + } + }() + default: + g.workersMutex.Unlock() } + } - if !g.preLock { - g.sema.Lock() + if g.discardIfFull { + select { + case g.scheduledJobs <- struct{}{}: + g.jobQueue <- fn + startWorkerIfNeeded() + default: + g.wg.Done() } - fn(g.ctx) - g.sema.Unlock() - }() + return + } + + g.scheduledJobs <- struct{}{} + g.jobQueue <- fn + startWorkerIfNeeded() } // Wait blocks until the SizedGroup counter is zero. @@ -73,3 +121,12 @@ func (g *SizedGroup) Go(fn func(ctx context.Context)) { func (g *SizedGroup) Wait() { g.wg.Wait() } + +func (g *SizedGroup) canceled() bool { + select { + case <-g.ctx.Done(): + return true + default: + return false + } +} diff --git a/sizedgroup_test.go b/sizedgroup_test.go index 581d7e0..98eba0a 100644 --- a/sizedgroup_test.go +++ b/sizedgroup_test.go @@ -17,7 +17,7 @@ func TestSizedGroup(t *testing.T) { var c uint32 for i := 0; i < 1000; i++ { - swg.Go(func(ctx context.Context) { + swg.Go(func(context.Context) { time.Sleep(5 * time.Millisecond) atomic.AddUint32(&c, 1) }) @@ -32,7 +32,7 @@ func TestSizedGroup_Discard(t *testing.T) { var c uint32 for i := 0; i < 100; i++ { - swg.Go(func(ctx context.Context) { + swg.Go(func(context.Context) { time.Sleep(5 * time.Millisecond) atomic.AddUint32(&c, 1) }) @@ -42,12 +42,27 @@ func TestSizedGroup_Discard(t *testing.T) { assert.Equal(t, uint32(10), c, fmt.Sprintf("%d, not all routines have been executed", c)) } +func TestSizedGroup_WithWrongSizeValuePassed(t *testing.T) { + swg := NewSizedGroup(0, Discard) + var c uint32 + + for i := 0; i < 100; i++ { + swg.Go(func(context.Context) { + time.Sleep(5 * time.Millisecond) + atomic.AddUint32(&c, 1) + }) + } + assert.True(t, runtime.NumGoroutine() < 6, "goroutines %d", runtime.NumGoroutine()) + swg.Wait() + assert.Equal(t, uint32(1), c, fmt.Sprintf("%d, wrong number of routines has been executed", c)) +} + func TestSizedGroup_Preemptive(t *testing.T) { swg := NewSizedGroup(10, Preemptive) var c uint32 for i := 0; i < 100; i++ { - swg.Go(func(ctx context.Context) { + swg.Go(func(context.Context) { time.Sleep(5 * time.Millisecond) atomic.AddUint32(&c, 1) }) @@ -57,14 +72,14 @@ func TestSizedGroup_Preemptive(t *testing.T) { assert.Equal(t, uint32(100), c, fmt.Sprintf("%d, not all routines have been executed", c)) } -func TestSizedGroup_Canceled(t *testing.T) { +func TestSizedGroup_CanceledPreemtiveMode(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() swg := NewSizedGroup(10, Preemptive, Context(ctx)) var c uint32 for i := 0; i < 100; i++ { - swg.Go(func(ctx context.Context) { + swg.Go(func(context.Context) { select { case <-ctx.Done(): return @@ -77,6 +92,71 @@ func TestSizedGroup_Canceled(t *testing.T) { assert.True(t, c < 100) } +func TestSizedGroup_CanceledNonPreemptiveMode(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + swg := NewSizedGroup(10, Context(ctx)) + var c uint32 + + for i := 0; i < 2000; i++ { + swg.Go(func(context.Context) { + select { + case <-ctx.Done(): + return + case <-time.After(5 * time.Millisecond): + } + atomic.AddUint32(&c, 1) + }) + } + swg.Wait() + assert.True(t, c < 25) +} + +func TestSizedGroup_DiscardAfterTreshold(t *testing.T) { + swg := NewSizedGroup(10, DiscardAfterTreshold(20)) + var c uint32 + + for i := 0; i < 100; i++ { + swg.Go(func(context.Context) { + time.Sleep(5 * time.Millisecond) + atomic.AddUint32(&c, 1) + }) + } + assert.True(t, runtime.NumGoroutine() < 15, "goroutines %d", runtime.NumGoroutine()) + swg.Wait() + assert.Equal(t, uint32(20), c, fmt.Sprintf("%d, wrong number of routines have been executed", c)) +} + +func TestSizedGroup_DiscardAfterTreshold_WithNegativeTreshold(t *testing.T) { + swg := NewSizedGroup(10, DiscardAfterTreshold(-1)) + var c uint32 + + for i := 0; i < 100; i++ { + swg.Go(func(context.Context) { + time.Sleep(5 * time.Millisecond) + atomic.AddUint32(&c, 1) + }) + } + assert.True(t, runtime.NumGoroutine() < 15, "goroutines %d", runtime.NumGoroutine()) + swg.Wait() + assert.Equal(t, uint32(10), c, fmt.Sprintf("%d, wrong number of routines have been executed", c)) +} + +func TestSizedGroup_DiscardAfterTreshold_WithTresholdNotAboveSize(t *testing.T) { + swg := NewSizedGroup(10, DiscardAfterTreshold(10)) + var c uint32 + + for i := 0; i < 100; i++ { + swg.Go(func(context.Context) { + time.Sleep(5 * time.Millisecond) + atomic.AddUint32(&c, 1) + }) + } + assert.True(t, runtime.NumGoroutine() < 15, "goroutines %d", runtime.NumGoroutine()) + swg.Wait() + assert.Equal(t, uint32(10), c, fmt.Sprintf("%d, wrong number of routines have been executed", c)) +} + // illustrates the use of a SizedGroup for concurrent, limited execution of goroutines. func ExampleSizedGroup_go() { @@ -84,7 +164,7 @@ func ExampleSizedGroup_go() { var c uint32 for i := 0; i < 1000; i++ { - grp.Go(func(ctx context.Context) { // Go call is non-blocking, like regular go statement + grp.Go(func(context.Context) { // Go call is non-blocking, like regular go statement // do some work in 10 goroutines in parallel atomic.AddUint32(&c, 1) time.Sleep(10 * time.Millisecond)