Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear pending tasks in the worker when the context is canceled to avoid deadlocks in StopAndWait when tasks are queued for the worker. #62

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pond.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,13 @@ func (p *WorkerPool) stop(waitForQueuedTasksToComplete bool) {
// Terminate all workers & purger goroutine
p.contextCancel()

// Wait for all workers & purger goroutine to exit
p.workersWaitGroup.Wait()

// close tasks channel (only once, in case multiple concurrent calls to StopAndWait are made)
p.tasksCloseOnce.Do(func() {
close(p.tasks)
})

// Wait for all workers & purger goroutine to exit
p.workersWaitGroup.Wait()
alitto marked this conversation as resolved.
Show resolved Hide resolved
}

// purge represents the work done by the purger goroutine
Expand Down Expand Up @@ -420,7 +420,7 @@ func (p *WorkerPool) maybeStartWorker(firstTask func()) bool {
}

// Launch worker goroutine
go worker(p.context, &p.workersWaitGroup, firstTask, p.tasks, p.executeTask)
go worker(p.context, &p.workersWaitGroup, firstTask, p.tasks, p.executeTask, &p.tasksWaitGroup)

return true
}
Expand Down
41 changes: 41 additions & 0 deletions pond_blackbox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,47 @@ func TestSubmitWithContext(t *testing.T) {
assertEqual(t, int32(0), atomic.LoadInt32(&doneCount))
}

func TestSubmitWithContextCancelWithIdleTasks(t *testing.T) {

	ctx, cancel := context.WithCancel(context.Background())

	pool := pond.New(1, 5, pond.Context(ctx))

	var startedCount, completedCount int32

	// A long-running task that blocks until either the pool context is
	// cancelled or a (deliberately unreachable) 10-minute timeout fires.
	// Only cancellation should ever win in this test.
	cancellableTask := func() {
		atomic.AddInt32(&startedCount, 1)
		select {
		case <-ctx.Done():
			return
		case <-time.After(10 * time.Minute):
			atomic.AddInt32(&completedCount, 1)
			return
		}
	}

	// Queue the task twice; the pool has a single worker, so the second
	// submission is expected to remain queued behind the first.
	pool.Submit(cancellableTask)
	pool.Submit(cancellableTask)

	// Cancel the context while one task runs and the other is still queued
	cancel()

	pool.StopAndWait()

	// Exactly one task should have started, and neither should have
	// reached the timeout branch.
	assertEqual(t, int32(1), atomic.LoadInt32(&startedCount))
	assertEqual(t, int32(0), atomic.LoadInt32(&completedCount))
}

func TestConcurrentStopAndWait(t *testing.T) {

pool := pond.New(1, 5)
Expand Down
5 changes: 5 additions & 0 deletions pond_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ func TestPurgeAfterPoolStopped(t *testing.T) {
pool.SubmitAndWait(func() {
atomic.AddInt32(&doneCount, 1)
})

time.Sleep(10 * time.Millisecond)

assertEqual(t, int32(1), atomic.LoadInt32(&doneCount))
assertEqual(t, 1, pool.RunningWorkers())

Expand All @@ -59,6 +62,8 @@ func TestPurgeDuringSubmit(t *testing.T) {
atomic.AddInt32(&doneCount, 1)
})

time.Sleep(10 * time.Millisecond)

assertEqual(t, 1, pool.IdleWorkers())

// Stop an idle worker right before submitting another task
Expand Down
33 changes: 25 additions & 8 deletions worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

// worker represents a worker goroutine
func worker(context context.Context, waitGroup *sync.WaitGroup, firstTask func(), tasks <-chan func(), taskExecutor func(func(), bool)) {
func worker(context context.Context, waitGroup *sync.WaitGroup, firstTask func(), tasks <-chan func(), taskExecutor func(func(), bool), taskWaitGroup *sync.WaitGroup) {

// If provided, execute the first task immediately, before listening to the tasks channel
if firstTask != nil {
Expand All @@ -20,16 +20,33 @@ func worker(context context.Context, waitGroup *sync.WaitGroup, firstTask func()
for {
select {
case <-context.Done():
// Pool context was cancelled, exit
// Pool context was cancelled, empty tasks channel and exit
drainTasks(tasks, taskWaitGroup)
return
case task, ok := <-tasks:
if task == nil || !ok {
// We have received a signal to exit
return
}
// Prioritize context.Done statement (https://stackoverflow.com/questions/46200343/force-priority-of-go-select-statement)
select {
case <-context.Done():
if task != nil && ok {
// We have received a task, ignore it
taskWaitGroup.Done()
}
default:
if task == nil || !ok {
// We have received a signal to exit
return
}

// We have received a task, execute it
taskExecutor(task, false)
// We have received a task, execute it
taskExecutor(task, false)
alitto marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}

// drainPendingTasks discards queued tasks and decrements the corresponding wait group
func drainTasks(tasks <-chan func(), tasksWaitGroup *sync.WaitGroup) {
for _ = range tasks {
tasksWaitGroup.Done()
}
}
Loading