diff --git a/lib/trie/branch/encode.go b/lib/trie/branch/encode.go index 04d7bc9cf9..92afb5d71b 100644 --- a/lib/trie/branch/encode.go +++ b/lib/trie/branch/encode.go @@ -8,6 +8,7 @@ import ( "fmt" "hash" "io" + "runtime" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/encode" @@ -120,12 +121,7 @@ func (b *Branch) Encode(buffer encode.Buffer) (err error) { } } - const parallel = false // TODO - if parallel { - err = encodeChildrenInParallel(b.Children, buffer) - } else { - err = encodeChildrenSequentially(b.Children, buffer) - } + err = encodeChildrenOpportunisticParallel(b.Children, buffer) if err != nil { return fmt.Errorf("cannot encode children of branch: %w", err) } @@ -133,30 +129,64 @@ func (b *Branch) Encode(buffer encode.Buffer) (err error) { return nil } -func encodeChildrenInParallel(children [16]node.Node, buffer io.Writer) (err error) { - type result struct { - index int - buffer *bytes.Buffer - err error +type encodingAsyncResult struct { + index int + buffer *bytes.Buffer + err error +} + +func runEncodeChild(child node.Node, index int, + results chan<- encodingAsyncResult, rateLimit <-chan struct{}) { + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + // buffer is put back in the pool after processing its + // data in the select block below. + + err := encodeChild(child, buffer) + + results <- encodingAsyncResult{ + index: index, + buffer: buffer, + err: err, + } + if rateLimit != nil { + // Only run if runEncodeChild is launched + // in its own goroutine. + <-rateLimit } +} + +var parallelLimit = runtime.NumCPU() + +var parallelEncodingRateLimit = make(chan struct{}, parallelLimit) - resultsCh := make(chan result) +// encodeChildrenOpportunisticParallel encodes children in parallel eventually. +// Leaves are encoded in a blocking way, and branches are encoded in separate +// goroutines IF they are less than the parallelLimit number of goroutines already +// running. This is designed to limit the total number of goroutines in order to +// avoid using too much memory on the stack. +func encodeChildrenOpportunisticParallel(children [16]node.Node, buffer io.Writer) (err error) { + // Buffered channels since children might be encoded in this + // goroutine or another one. + resultsCh := make(chan encodingAsyncResult, len(children)) for i, child := range children { - go func(index int, child node.Node) { - buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) - buffer.Reset() - // buffer is put back in the pool after processing its - // data in the select block below. - - err := encodeChild(child, buffer) - - resultsCh <- result{ - index: index, - buffer: buffer, - err: err, - } - }(i, child) + if isNodeNil(child) || child.Type() == node.LeafType { + runEncodeChild(child, i, resultsCh, nil) + continue + } + + // Branch child + select { + case parallelEncodingRateLimit <- struct{}{}: + // We have a goroutine available to encode + // the branch in parallel. + go runEncodeChild(child, i, resultsCh, parallelEncodingRateLimit) + default: + // we reached the maximum parallel goroutines + // so encode this branch in this goroutine + runEncodeChild(child, i, resultsCh, nil) + } } currentIndex := 0 @@ -173,7 +203,7 @@ func encodeChildrenInParallel(children [16]node.Node, buffer io.Writer) (err err for currentIndex < len(children) && resultBuffers[currentIndex] != nil { bufferSlice := resultBuffers[currentIndex].Bytes() - if len(bufferSlice) > 0 { + if err == nil && len(bufferSlice) > 0 { // note buffer.Write copies the byte slice given as argument _, writeErr := buffer.Write(bufferSlice) if writeErr != nil && err == nil { @@ -210,17 +240,20 @@ func encodeChildrenSequentially(children [16]node.Node, buffer io.Writer) (err e return nil } -func encodeChild(child node.Node, buffer io.Writer) (err error) { - var isNil bool - switch impl := child.(type) { +func isNodeNil(n node.Node) (isNil bool) { + switch impl := n.(type) { case *Branch: isNil = impl == nil case *leaf.Leaf: isNil = impl == nil default: - isNil = child == nil + isNil = n == nil } - if isNil { + return isNil +} + +func encodeChild(child node.Node, buffer io.Writer) (err error) { + if isNodeNil(child) { return nil } diff --git a/lib/trie/branch/encode_test.go b/lib/trie/branch/encode_test.go index 7153329bc5..9d0a988e52 100644 --- a/lib/trie/branch/encode_test.go +++ b/lib/trie/branch/encode_test.go @@ -146,7 +146,7 @@ func Test_Branch_Encode(t *testing.T) { wrappedErr: errTest, errMessage: "cannot write encoded value to buffer: test error", }, - "buffer write error for children encoded sequentially": { + "buffer write error for children encoding": { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, @@ -175,10 +175,10 @@ func Test_Branch_Encode(t *testing.T) { }, wrappedErr: errTest, errMessage: "cannot encode children of branch: " + - "cannot encode child at index 3: " + - "failed to write child to buffer: test error", + "cannot write encoding of child at index 3: " + + "test error", }, - "success with sequential children encoding": { + "success with children encoding": { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, @@ -241,7 +241,7 @@ func Test_Branch_Encode(t *testing.T) { } } -func Test_encodeChildrenInParallel(t *testing.T) { +func Test_encodeChildrenOpportunisticParallel(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -329,7 +329,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { previousCall = call } - err := encodeChildrenInParallel(testCase.children, buffer) + err := encodeChildrenOpportunisticParallel(testCase.children, buffer) if testCase.wrappedErr != nil { assert.ErrorIs(t, err, testCase.wrappedErr)