From 3287b54c3ffb371a16fb03b3ea73d9871d8c5371 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Sun, 17 Mar 2024 21:35:01 -0700 Subject: [PATCH] Simplify the implementation of the Collector. Instead of maintaining a separate goroutine to synchronize delivery of values, rework the collector to use a plain sync.Mutex. This: - Greatly simplifies the code (with one exception, noted below). - Eliminates the need for a separate goroutine to service values. Each task now handles its own service, mediated by the collector. That, in turn: - Eliminates the need to Wait for the Collector: Once all the goroutines running tasks in the collector have exited, the state is fully settled. The Wait method is now a no-op, and is marked as deprecated. In addition, add a new Report method, replacing Stream. Instead of a channel, tasks using this method accept a report function that sends values to the collector. The report function ensures control does not return to the task until the reported value has been serviced, which allows tasks to ensure they do not exit until all their values have been addressed. The Stream method still works, but is deprecated. To preserve its interface, each Stream call now spins up a new goroutine to service the values from its task. This is wasteful, but easily replaced by switching to Report. Co-Authored-By: David Anderson --- README.md | 31 ++++++++++++-------- collector.go | 75 +++++++++++++++++++++++++++++------------------ example_test.go | 22 ++++++-------- taskgroup_test.go | 42 ++++++++++++++++++++++---- 4 files changed, 109 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 298a256..245789a 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ A `*taskgroup.Group` represents a group of goroutines working on related tasks. New tasks can be added to the group at will, and the caller can wait until all -tasks are complete. 
Errors are automatically collected and delivered to a -user-provided callback in a single goroutine. This does not replace the full +tasks are complete. Errors are automatically collected and delivered +synchronously to a user-provided callback. This does not replace the full generality of Go's built-in features, but it simplifies some of the plumbing for common concurrent tasks. @@ -229,7 +229,7 @@ start(task4) // blocks until one of the previous tasks is finished ## Solo Tasks -In some cases it is useful to start a solo background task to handle an +In some cases it is useful to start a single background task to handle an isolated concern. For example, suppose we want to read a file into a buffer while we take care of some other work. Rather than creating a whole group for a single goroutine, we can create a solo task using the `Go` constructor. @@ -270,13 +270,11 @@ var sum int c := taskgroup.NewCollector(func(v int) { sum += v }) ``` -Internally, a `Collector` wraps a solo task and a channel to receive results. - -The `Task` and `NoError` methods of the collector `c` can then be used to wrap -a function that reports a value. If the function reports an error, that error -is returned from the task as usual. Otherwise, its non-error value is given to -the callback. As in the above example, calls to the function are serialized so -that it is safe to access state without additional locking: +The `Task`, `NoError`, and `Report` methods of `c` wrap a function that yields +a value into a task. If the function reports an error, that error is returned +from the task as usual. Otherwise, its non-error value is given to the +accumulator callback. As in the above example, calls to the callback are +serialized so that it is safe to access state without additional locking: ```go // Report an error, no value for the collector. @@ -291,14 +289,21 @@ g.Go(c.Task(func() (int, error) { // Report a random integer to the collector. 
g.Go(c.NoError(func() int { return rand.Intn(1000) }) + +// Report multiple values to the collector. +g.Go(c.Report(func(report func(int)) error { + report(10) + report(20) + report(30) + return nil +})) ``` -Once all the tasks are done, call `Wait` to stop the collector and wait for it -to finish: +Once all the tasks derived from the collector are done, it is safe to access +the values accumulated by the callback: ```go g.Wait() // wait for tasks to finish -c.Wait() // wait for the collector to finish // Now you can access the values accumulated by c. fmt.Println(sum) diff --git a/collector.go b/collector.go index 4490924..6dc9b03 100644 --- a/collector.go +++ b/collector.go @@ -1,37 +1,35 @@ package taskgroup +import "sync" + // A Collector collects values reported by task functions and delivers them to // an accumulator function. type Collector[T any] struct { - ch chan<- T - s *Single[error] + μ sync.Mutex + handle func(T) +} + +// report delivers v to the callback under the lock. +func (c *Collector[T]) report(v T) { + c.μ.Lock() + defer c.μ.Unlock() + c.handle(v) } // NewCollector creates a new collector that delivers task values to the // specified accumulator function. The collector serializes calls to value, so -// that it is safe for the function to access shared state without a lock. The -// caller must call Wait when the collector is no longer needed, even if it has -// not been used. -func NewCollector[T any](value func(T)) *Collector[T] { - ch := make(chan T) - s := Go(NoError(func() { - for v := range ch { - value(v) - } - })) - return &Collector[T]{ch: ch, s: s} -} +// that it is safe for the function to access shared state without a lock. +// +// The tasks created from a collector do not return until all the values +// reported by the underlying function have been processed by the accumulator. 
+func NewCollector[T any](value func(T)) *Collector[T] { return &Collector[T]{handle: value} } -// Wait stops the collector and blocks until it has finished processing. -// It is safe to call Wait multiple times from a single goroutine. -// Note that after Wait has been called, c is no longer valid. -func (c *Collector[T]) Wait() { - if c.ch != nil { - close(c.ch) - c.ch = nil - c.s.Wait() - } -} +// Wait waits until the collector has finished processing. +// +// Deprecated: This method is now a noop; it is safe but unnecessary to call +// it. The state serviced by c is settled once all the goroutines writing to +// the collector have returned. It may be removed in a future version. +func (c *Collector[T]) Wait() {} // Task returns a Task wrapping a call to f. If f reports an error, that error // is propagated as the return value of the task; otherwise, the non-error @@ -42,21 +40,40 @@ func (c *Collector[T]) Task(f func() (T, error)) Task { if err != nil { return err } - c.ch <- v + c.report(v) return nil } } +// Report returns a task wrapping a call to f, which is passed a function that +// sends results to the accumulator. The report function does not return until +// the accumulator has finished processing the value. +func (c *Collector[T]) Report(f func(report func(T)) error) Task { + return func() error { return f(c.report) } +} + // Stream returns a task wrapping a call to f, which is passed a channel on -// which results can be sent to the accumulator. +// which results can be sent to the accumulator. Each call to Stream starts a +// goroutine to process the values on the channel. // -// Note: f must not close its argument channel. +// Deprecated: Tasks that wish to deliver multiple values should use Report +// instead, which does not spawn a goroutine. This method may be removed in a +// future version. 
func (c *Collector[T]) Stream(f func(chan<- T) error) Task { - return func() error { return f(c.ch) } + return func() error { + ch := make(chan T) + s := Go(NoError(func() { + for v := range ch { + c.report(v) + } + })) + defer func() { close(ch); s.Wait() }() + return f(ch) + } } // NoError returns a Task wrapping a call to f. The resulting task reports a // nil error for all calls. func (c *Collector[T]) NoError(f func() T) Task { - return NoError(func() { c.ch <- f() }) + return NoError(func() { c.report(f()) }) } diff --git a/example_test.go b/example_test.go index 51e44f2..2c64992 100644 --- a/example_test.go +++ b/example_test.go @@ -189,7 +189,6 @@ func ExampleCollector() { // Wait for the searchers to finish, then signal the collector to stop. g.Wait() - c.Wait() // Now get the final result. fmt.Println(total) @@ -197,7 +196,7 @@ func ExampleCollector() { // 325 } -func ExampleCollector_Stream() { +func ExampleCollector_Report() { type val struct { who string v int @@ -205,29 +204,26 @@ func ExampleCollector_Stream() { c := taskgroup.NewCollector(func(z val) { fmt.Println(z.who, z.v) }) err := taskgroup.New(nil). - // The Stream method passes its argument a channel where it may report - // multiple values to the collector. - Go(c.Stream(func(zs chan<- val) error { + // The Report method passes its argument a function to report multiple + // values to the collector. + Go(c.Report(func(report func(v val)) error { for i := 0; i < 3; i++ { - zs <- val{"even", 2 * i} + report(val{"even", 2 * i}) } return nil })). - // Multiple streams are fine. - Go(c.Stream(func(zs chan<- val) error { + // Multiple reporters are fine. + Go(c.Report(func(report func(v val)) error { for i := 0; i < 3; i++ { - zs <- val{"odd", 2*i + 1} + report(val{"odd", 2*i + 1}) } - // An error reported by a stream is propagated just like any other - // task error. + // An error from a reporter is propagated like any other task error. return errors.New("no bueno") })). 
Wait() if err == nil || err.Error() != "no bueno" { log.Fatalf("Unexpected error: %v", err) } - - c.Wait() // Unordered output: // even 0 // odd 1 diff --git a/taskgroup_test.go b/taskgroup_test.go index 9ae6a67..f54234e 100644 --- a/taskgroup_test.go +++ b/taskgroup_test.go @@ -208,6 +208,39 @@ func TestSingleTask(t *testing.T) { }) } +func TestWaitMoreTasks(t *testing.T) { + defer leaktest.Check(t)() + + var results int + coll := taskgroup.NewCollector(func(int) { + results++ + }) + + g := taskgroup.New(nil) + + // Test that if a task spawns more tasks on its own recognizance, waiting + // correctly waits for all of them provided we do not let the group go empty + // before all the tasks are spawned. + var countdown func(int) int + countdown = func(n int) int { + if n > 1 { + // The subordinate task, if there is one, is started before this one + // exits, ensuring the group is kept "afloat". + g.Go(coll.NoError(func() int { + return countdown(n - 1) + })) + } + return n + } + + g.Go(coll.NoError(func() int { return countdown(15) })) + g.Wait() + + if results != 15 { + t.Errorf("Got %d results, want 10", results) + } +} + func TestSingleResult(t *testing.T) { defer leaktest.Check(t)() @@ -250,22 +283,20 @@ func TestCollector(t *testing.T) { g.Go(c.NoError(func() int { return v })) } } - g.Wait() // wait for tasks to finish - c.Wait() // wait for collector if want := (10 * 11) / 2; sum != want { t.Errorf("Final result: got %d, want %d", sum, want) } } -func TestCollector_Stream(t *testing.T) { +func TestCollector_Report(t *testing.T) { var sum int c := taskgroup.NewCollector(func(v int) { sum += v }) - g := taskgroup.New(nil).Go(c.Stream(func(vs chan<- int) error { + g := taskgroup.New(nil).Go(c.Report(func(report func(v int)) error { for _, v := range shuffled(10) { - vs <- v + report(v) } return nil })) @@ -273,7 +304,6 @@ func TestCollector_Stream(t *testing.T) { if err := g.Wait(); err != nil { t.Errorf("Unexpected error from group: %v", err) } - c.Wait() if 
want := (10 * 11) / 2; sum != want { t.Errorf("Final result: got %d, want %d", sum, want) }