Skip to content

Commit

Permalink
Simplify the implementation of the Collector.
Browse files Browse the repository at this point in the history
Instead of maintaining a separate goroutine to synchronize delivery of values,
rework the collector to use a plain sync.Mutex.

This:

- Greatly simplifies the code (with one exception, noted below).

- Eliminates the need for a separate goroutine to service values. Each task now
  handles its own service, mediated by the collector. That, in turn:

- Eliminates the need to Wait for the Collector: Once all the goroutines
  running tasks in the collector have exited, the state is fully settled.
  The Wait method is now a no-op, and is marked as deprecated.

In addition, add a new Report method, replacing Stream. Instead of a channel,
tasks using this method accepts a report function that sends values to the
collector.  The report function ensures control does not return to the task
until the reported value has been serviced, which allows tasks to ensure they
do not exit until all their values have been addressed.

The Stream method still works, but is deprecated. To preserve its interface,
each Stream call now spins up a new goroutine to service the values from its
task. This is wasteful, but easily replaced by switching to Report.

Co-Authored-By: David Anderson <dave@natulte.net>
  • Loading branch information
creachadair and danderson committed Mar 19, 2024
1 parent ebdb7e5 commit 3287b54
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 61 deletions.
31 changes: 18 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

A `*taskgroup.Group` represents a group of goroutines working on related tasks.
New tasks can be added to the group at will, and the caller can wait until all
tasks are complete. Errors are automatically collected and delivered to a
user-provided callback in a single goroutine. This does not replace the full
tasks are complete. Errors are automatically collected and delivered
synchronously to a user-provided callback. This does not replace the full
generality of Go's built-in features, but it simplifies some of the plumbing
for common concurrent tasks.

Expand Down Expand Up @@ -229,7 +229,7 @@ start(task4) // blocks until one of the previous tasks is finished

## Solo Tasks

In some cases it is useful to start a solo background task to handle an
In some cases it is useful to start a single background task to handle an
isolated concern. For example, suppose we want to read a file into a buffer
while we take care of some other work. Rather than creating a whole group for
a single goroutine, we can create a solo task using the `Go` constructor.
Expand Down Expand Up @@ -270,13 +270,11 @@ var sum int
c := taskgroup.NewCollector(func(v int) { sum += v })
```

Internally, a `Collector` wraps a solo task and a channel to receive results.

The `Task` and `NoError` methods of the collector `c` can then be used to wrap
a function that reports a value. If the function reports an error, that error
is returned from the task as usual. Otherwise, its non-error value is given to
the callback. As in the above example, calls to the function are serialized so
that it is safe to access state without additional locking:
The `Task`, `NoError`, and `Report` methods of `c` wrap a function that yields
a value into a task. If the function reports an error, that error is returned
from the task as usual. Otherwise, its non-error value is given to the
accumulator callback. As in the above example, calls to the function are
serialized so that it is safe to access state without additional locking:

```go
// Report an error, no value for the collector.
Expand All @@ -291,14 +289,21 @@ g.Go(c.Task(func() (int, error) {

// Report a random integer to the collector.
g.Go(c.NoError(func() int { return rand.Intn(1000) })

// Report multiple values to the collector.
g.Go(c.Report(func(report func(int)) error {
report(10)
report(20)
report(30)
return nil
}))
```

Once all the tasks are done, call `Wait` to stop the collector and wait for it
to finish:
Once all the tasks derived from the collector are done, it is safe to access
the values accumulated by the callback:

```go
g.Wait() // wait for tasks to finish
c.Wait() // wait for the collector to finish

// Now you can access the values accumulated by c.
fmt.Println(sum)
Expand Down
75 changes: 46 additions & 29 deletions collector.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,35 @@
package taskgroup

import "sync"

// A Collector collects values reported by task functions and delivers them to
// an accumulator function.
type Collector[T any] struct {
ch chan<- T
s *Single[error]
μ sync.Mutex
handle func(T)
}

// report delivers v to the callback under the lock.
func (c *Collector[T]) report(v T) {
c.μ.Lock()
defer c.μ.Unlock()
c.handle(v)
}

// NewCollector creates a new collector that delivers task values to the
// specified accumulator function. The collector serializes calls to value, so
// that it is safe for the function to access shared state without a lock. The
// caller must call Wait when the collector is no longer needed, even if it has
// not been used.
func NewCollector[T any](value func(T)) *Collector[T] {
ch := make(chan T)
s := Go(NoError(func() {
for v := range ch {
value(v)
}
}))
return &Collector[T]{ch: ch, s: s}
}
// that it is safe for the function to access shared state without a lock.
//
// The tasks created from a collector do not return until all the values
// reported by the underlying function have been processed by the accumulator.
func NewCollector[T any](value func(T)) *Collector[T] { return &Collector[T]{handle: value} }

// Wait stops the collector and blocks until it has finished processing.
// It is safe to call Wait multiple times from a single goroutine.
// Note that after Wait has been called, c is no longer valid.
func (c *Collector[T]) Wait() {
if c.ch != nil {
close(c.ch)
c.ch = nil
c.s.Wait()
}
}
// Wait waits until the collector has finished processing.
//
// Deprecated: This method is now a noop; it is safe but unnecessary to call
// it. The state serviced by c is settled once all the goroutines writing to
// the collector have returned. It may be removed in a future version.
func (c *Collector[T]) Wait() {}

// Task returns a Task wrapping a call to f. If f reports an error, that error
// is propagated as the return value of the task; otherwise, the non-error
Expand All @@ -42,21 +40,40 @@ func (c *Collector[T]) Task(f func() (T, error)) Task {
if err != nil {
return err
}
c.ch <- v
c.report(v)
return nil
}
}

// Report returns a task wrapping a call to f, which is passed a function that
// sends results to the accumulator. The report function does not return until
// the accumulator has finished processing the value.
func (c *Collector[T]) Report(f func(report func(T)) error) Task {
return func() error { return f(c.report) }
}

// Stream returns a task wrapping a call to f, which is passed a channel on
// which results can be sent to the accumulator.
// which results can be sent to the accumulator. Each call to Stream starts a
// goroutine to process the values on the channel.
//
// Note: f must not close its argument channel.
// Deprecated: Tasks that wish to deliver multiple values should use Report
// instead, which does not spawn a goroutine. This method may be removed in a
// future version.
func (c *Collector[T]) Stream(f func(chan<- T) error) Task {
return func() error { return f(c.ch) }
return func() error {
ch := make(chan T)
s := Go(NoError(func() {
for v := range ch {
c.report(v)
}
}))
defer func() { close(ch); s.Wait() }()
return f(ch)
}
}

// NoError returns a Task wrapping a call to f. The resulting task reports a
// nil error for all calls.
func (c *Collector[T]) NoError(f func() T) Task {
return NoError(func() { c.ch <- f() })
return NoError(func() { c.report(f()) })
}
22 changes: 9 additions & 13 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,45 +189,41 @@ func ExampleCollector() {

// Wait for the searchers to finish, then signal the collector to stop.
g.Wait()
c.Wait()

// Now get the final result.
fmt.Println(total)
// Output:
// 325
}

func ExampleCollector_Stream() {
func ExampleCollector_Report() {
type val struct {
who string
v int
}
c := taskgroup.NewCollector(func(z val) { fmt.Println(z.who, z.v) })

err := taskgroup.New(nil).
// The Stream method passes its argument a channel where it may report
// multiple values to the collector.
Go(c.Stream(func(zs chan<- val) error {
// The Report method passes its argument a function to report multiple
// values to the collector.
Go(c.Report(func(report func(v val)) error {
for i := 0; i < 3; i++ {
zs <- val{"even", 2 * i}
report(val{"even", 2 * i})
}
return nil
})).
// Multiple streams are fine.
Go(c.Stream(func(zs chan<- val) error {
// Multiple reporters are fine.
Go(c.Report(func(report func(v val)) error {
for i := 0; i < 3; i++ {
zs <- val{"odd", 2*i + 1}
report(val{"odd", 2*i + 1})
}
// An error reported by a stream is propagated just like any other
// task error.
// An error from a reporter is propagated like any other task error.
return errors.New("no bueno")
})).
Wait()
if err == nil || err.Error() != "no bueno" {
log.Fatalf("Unexpected error: %v", err)
}

c.Wait()
// Unordered output:
// even 0
// odd 1
Expand Down
42 changes: 36 additions & 6 deletions taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,39 @@ func TestSingleTask(t *testing.T) {
})
}

func TestWaitMoreTasks(t *testing.T) {
defer leaktest.Check(t)()

var results int
coll := taskgroup.NewCollector(func(int) {
results++
})

g := taskgroup.New(nil)

// Test that if a task spawns more tasks on its own recognizance, waiting
// correctly waits for all of them provided we do not let the group go empty
// before all the tasks are spawned.
var countdown func(int) int
countdown = func(n int) int {
if n > 1 {
// The subordinate task, if there is one, is started before this one
// exits, ensuring the group is kept "afloat".
g.Go(coll.NoError(func() int {
return countdown(n - 1)
}))
}
return n
}

g.Go(coll.NoError(func() int { return countdown(15) }))
g.Wait()

if results != 15 {
t.Errorf("Got %d results, want 10", results)
}
}

func TestSingleResult(t *testing.T) {
defer leaktest.Check(t)()

Expand Down Expand Up @@ -250,30 +283,27 @@ func TestCollector(t *testing.T) {
g.Go(c.NoError(func() int { return v }))
}
}

g.Wait() // wait for tasks to finish
c.Wait() // wait for collector

if want := (10 * 11) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
}

func TestCollector_Stream(t *testing.T) {
func TestCollector_Report(t *testing.T) {
var sum int
c := taskgroup.NewCollector(func(v int) { sum += v })

g := taskgroup.New(nil).Go(c.Stream(func(vs chan<- int) error {
g := taskgroup.New(nil).Go(c.Report(func(report func(v int)) error {
for _, v := range shuffled(10) {
vs <- v
report(v)
}
return nil
}))

if err := g.Wait(); err != nil {
t.Errorf("Unexpected error from group: %v", err)
}
c.Wait()
if want := (10 * 11) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
Expand Down

0 comments on commit 3287b54

Please sign in to comment.