diff --git a/README.md b/README.md index 298a256..245789a 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ A `*taskgroup.Group` represents a group of goroutines working on related tasks. New tasks can be added to the group at will, and the caller can wait until all -tasks are complete. Errors are automatically collected and delivered to a -user-provided callback in a single goroutine. This does not replace the full +tasks are complete. Errors are automatically collected and delivered +synchronously to a user-provided callback. This does not replace the full generality of Go's built-in features, but it simplifies some of the plumbing for common concurrent tasks. @@ -229,7 +229,7 @@ start(task4) // blocks until one of the previous tasks is finished ## Solo Tasks -In some cases it is useful to start a solo background task to handle an +In some cases it is useful to start a single background task to handle an isolated concern. For example, suppose we want to read a file into a buffer while we take care of some other work. Rather than creating a whole group for a single goroutine, we can create a solo task using the `Go` constructor. @@ -270,13 +270,11 @@ var sum int c := taskgroup.NewCollector(func(v int) { sum += v }) ``` -Internally, a `Collector` wraps a solo task and a channel to receive results. - -The `Task` and `NoError` methods of the collector `c` can then be used to wrap -a function that reports a value. If the function reports an error, that error -is returned from the task as usual. Otherwise, its non-error value is given to -the callback. As in the above example, calls to the function are serialized so -that it is safe to access state without additional locking: +The `Task`, `NoError`, and `Report` methods of `c` wrap a function that yields +a value into a task. If the function reports an error, that error is returned +from the task as usual. Otherwise, its non-error value is given to the +accumulator callback. As in the above example, calls to the function are +serialized so that it is safe to access state without additional locking: ```go // Report an error, no value for the collector. @@ -291,14 +289,21 @@ g.Go(c.Task(func() (int, error) { // Report a random integer to the collector. g.Go(c.NoError(func() int { return rand.Intn(1000) }) + +// Report multiple values to the collector. +g.Go(c.Report(func(report func(int)) error { + report(10) + report(20) + report(30) + return nil +})) ``` -Once all the tasks are done, call `Wait` to stop the collector and wait for it -to finish: +Once all the tasks derived from the collector are done, it is safe to access +the values accumulated by the callback: ```go g.Wait() // wait for tasks to finish -c.Wait() // wait for the collector to finish // Now you can access the values accumulated by c. fmt.Println(sum) diff --git a/collector.go b/collector.go index 4490924..6dc9b03 100644 --- a/collector.go +++ b/collector.go @@ -1,37 +1,35 @@ package taskgroup +import "sync" + // A Collector collects values reported by task functions and delivers them to // an accumulator function. type Collector[T any] struct { - ch chan<- T - s *Single[error] + μ sync.Mutex + handle func(T) +} + +// report delivers v to the callback under the lock. +func (c *Collector[T]) report(v T) { + c.μ.Lock() + defer c.μ.Unlock() + c.handle(v) } // NewCollector creates a new collector that delivers task values to the // specified accumulator function. The collector serializes calls to value, so -// that it is safe for the function to access shared state without a lock. The -// caller must call Wait when the collector is no longer needed, even if it has -// not been used. -func NewCollector[T any](value func(T)) *Collector[T] { - ch := make(chan T) - s := Go(NoError(func() { - for v := range ch { - value(v) - } - })) - return &Collector[T]{ch: ch, s: s} -} +// that it is safe for the function to access shared state without a lock. +// +// The tasks created from a collector do not return until all the values +// reported by the underlying function have been processed by the accumulator. +func NewCollector[T any](value func(T)) *Collector[T] { return &Collector[T]{handle: value} } -// Wait stops the collector and blocks until it has finished processing. -// It is safe to call Wait multiple times from a single goroutine. -// Note that after Wait has been called, c is no longer valid. -func (c *Collector[T]) Wait() { - if c.ch != nil { - close(c.ch) - c.ch = nil - c.s.Wait() - } -} +// Wait waits until the collector has finished processing. +// +// Deprecated: This method is now a noop; it is safe but unnecessary to call +// it. The state serviced by c is settled once all the goroutines writing to +// the collector have returned. It may be removed in a future version. +func (c *Collector[T]) Wait() {} // Task returns a Task wrapping a call to f. If f reports an error, that error // is propagated as the return value of the task; otherwise, the non-error @@ -42,21 +40,40 @@ func (c *Collector[T]) Task(f func() (T, error)) Task { if err != nil { return err } - c.ch <- v + c.report(v) return nil } } +// Report returns a task wrapping a call to f, which is passed a function that +// sends results to the accumulator. The report function does not return until +// the accumulator has finished processing the value. +func (c *Collector[T]) Report(f func(report func(T)) error) Task { + return func() error { return f(c.report) } +} + // Stream returns a task wrapping a call to f, which is passed a channel on -// which results can be sent to the accumulator. +// which results can be sent to the accumulator. Each call to Stream starts a +// goroutine to process the values on the channel. // -// Note: f must not close its argument channel. +// Deprecated: Tasks that wish to deliver multiple values should use Report +// instead, which does not spawn a goroutine. This method may be removed in a +// future version. func (c *Collector[T]) Stream(f func(chan<- T) error) Task { - return func() error { return f(c.ch) } + return func() error { + ch := make(chan T) + s := Go(NoError(func() { + for v := range ch { + c.report(v) + } + })) + defer func() { close(ch); s.Wait() }() + return f(ch) + } } // NoError returns a Task wrapping a call to f. The resulting task reports a // nil error for all calls. func (c *Collector[T]) NoError(f func() T) Task { - return NoError(func() { c.ch <- f() }) + return NoError(func() { c.report(f()) }) } diff --git a/example_test.go b/example_test.go index 51e44f2..2c64992 100644 --- a/example_test.go +++ b/example_test.go @@ -189,7 +189,6 @@ func ExampleCollector() { // Wait for the searchers to finish, then signal the collector to stop. g.Wait() - c.Wait() // Now get the final result. fmt.Println(total) @@ -197,7 +196,7 @@ func ExampleCollector() { // 325 } -func ExampleCollector_Stream() { +func ExampleCollector_Report() { type val struct { who string v int @@ -205,29 +204,26 @@ func ExampleCollector_Stream() { c := taskgroup.NewCollector(func(z val) { fmt.Println(z.who, z.v) }) err := taskgroup.New(nil). - // The Stream method passes its argument a channel where it may report - // multiple values to the collector. - Go(c.Stream(func(zs chan<- val) error { + // The Report method passes its argument a function to report multiple + // values to the collector. + Go(c.Report(func(report func(v val)) error { for i := 0; i < 3; i++ { - zs <- val{"even", 2 * i} + report(val{"even", 2 * i}) } return nil })). - // Multiple streams are fine. - Go(c.Stream(func(zs chan<- val) error { + // Multiple reporters are fine. + Go(c.Report(func(report func(v val)) error { for i := 0; i < 3; i++ { - zs <- val{"odd", 2*i + 1} + report(val{"odd", 2*i + 1}) } - // An error reported by a stream is propagated just like any other - // task error. + // An error from a reporter is propagated like any other task error. return errors.New("no bueno") })). Wait() if err == nil || err.Error() != "no bueno" { log.Fatalf("Unexpected error: %v", err) } - - c.Wait() // Unordered output: // even 0 // odd 1 diff --git a/taskgroup_test.go b/taskgroup_test.go index 9ae6a67..f54234e 100644 --- a/taskgroup_test.go +++ b/taskgroup_test.go @@ -208,6 +208,39 @@ func TestSingleTask(t *testing.T) { }) } +func TestWaitMoreTasks(t *testing.T) { + defer leaktest.Check(t)() + + var results int + coll := taskgroup.NewCollector(func(int) { + results++ + }) + + g := taskgroup.New(nil) + + // Test that if a task spawns more tasks on its own recognizance, waiting + // correctly waits for all of them provided we do not let the group go empty + // before all the tasks are spawned. + var countdown func(int) int + countdown = func(n int) int { + if n > 1 { + // The subordinate task, if there is one, is started before this one + // exits, ensuring the group is kept "afloat". + g.Go(coll.NoError(func() int { + return countdown(n - 1) + })) + } + return n + } + + g.Go(coll.NoError(func() int { return countdown(15) })) + g.Wait() + + if results != 15 { + t.Errorf("Got %d results, want 10", results) + } +} + func TestSingleResult(t *testing.T) { defer leaktest.Check(t)() @@ -250,22 +283,20 @@ func TestCollector(t *testing.T) { g.Go(c.NoError(func() int { return v })) } } - g.Wait() // wait for tasks to finish - c.Wait() // wait for collector if want := (10 * 11) / 2; sum != want { t.Errorf("Final result: got %d, want %d", sum, want) } } -func TestCollector_Stream(t *testing.T) { +func TestCollector_Report(t *testing.T) { var sum int c := taskgroup.NewCollector(func(v int) { sum += v }) - g := taskgroup.New(nil).Go(c.Stream(func(vs chan<- int) error { + g := taskgroup.New(nil).Go(c.Report(func(report func(v int)) error { for _, v := range shuffled(10) { - vs <- v + report(v) } return nil })) @@ -273,7 +304,6 @@ func TestCollector_Stream(t *testing.T) { if err := g.Wait(); err != nil { t.Errorf("Unexpected error from group: %v", err) } - c.Wait() if want := (10 * 11) / 2; sum != want { t.Errorf("Final result: got %d, want %d", sum, want) }