diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 5c3ee95fbc..e2a32af50c 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -8,6 +8,7 @@ import ( "fmt" "hash" "io" + "runtime" "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/pools" @@ -113,12 +114,7 @@ func (b *Branch) Encode(buffer Buffer) (err error) { } } - const parallel = false // TODO - if parallel { - err = encodeChildrenInParallel(b.Children, buffer) - } else { - err = encodeChildrenSequentially(b.Children, buffer) - } + err = encodeChildrenOpportunisticParallel(b.Children, buffer) if err != nil { return fmt.Errorf("cannot encode children of branch: %w", err) } @@ -126,30 +122,64 @@ func (b *Branch) Encode(buffer Buffer) (err error) { return nil } -func encodeChildrenInParallel(children [16]Node, buffer io.Writer) (err error) { - type result struct { - index int - buffer *bytes.Buffer - err error +type encodingAsyncResult struct { + index int + buffer *bytes.Buffer + err error +} + +func runEncodeChild(child Node, index int, + results chan<- encodingAsyncResult, rateLimit <-chan struct{}) { + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + // buffer is put back in the pool after processing its + // data in the select block below. + + err := encodeChild(child, buffer) + + results <- encodingAsyncResult{ + index: index, + buffer: buffer, + err: err, + } + if rateLimit != nil { + // Only run if runEncodeChild is launched + // in its own goroutine. + <-rateLimit } +} + +var parallelLimit = runtime.NumCPU() + +var parallelEncodingRateLimit = make(chan struct{}, parallelLimit) - resultsCh := make(chan result) +// encodeChildrenOpportunisticParallel encodes children in parallel eventually. +// Leaves are encoded in a blocking way, and branches are encoded in separate +// goroutines IF they are less than the parallelLimit number of goroutines already +// running. This is designed to limit the total number of goroutines in order to +// avoid using too much memory on the stack. +func encodeChildrenOpportunisticParallel(children [16]Node, buffer io.Writer) (err error) { + // Buffered channels since children might be encoded in this + // goroutine or another one. + resultsCh := make(chan encodingAsyncResult, len(children)) for i, child := range children { - go func(index int, child Node) { - buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) - buffer.Reset() - // buffer is put back in the pool after processing its - // data in the select block below. - - err := encodeChild(child, buffer) - - resultsCh <- result{ - index: index, - buffer: buffer, - err: err, - } - }(i, child) + if isNodeNil(child) || child.Type() == LeafType { + runEncodeChild(child, i, resultsCh, nil) + continue + } + + // Branch child + select { + case parallelEncodingRateLimit <- struct{}{}: + // We have a goroutine available to encode + // the branch in parallel. + go runEncodeChild(child, i, resultsCh, parallelEncodingRateLimit) + default: + // we reached the maximum parallel goroutines + // so encode this branch in this goroutine + runEncodeChild(child, i, resultsCh, nil) + } } currentIndex := 0 @@ -166,7 +196,7 @@ func encodeChildrenInParallel(children [16]Node, buffer io.Writer) (err error) { for currentIndex < len(children) && resultBuffers[currentIndex] != nil { bufferSlice := resultBuffers[currentIndex].Bytes() - if len(bufferSlice) > 0 { + if err == nil && len(bufferSlice) > 0 { // note buffer.Write copies the byte slice given as argument _, writeErr := buffer.Write(bufferSlice) if writeErr != nil && err == nil { @@ -203,17 +233,20 @@ func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error) return nil } -func encodeChild(child Node, buffer io.Writer) (err error) { - var isNil bool - switch impl := child.(type) { +func isNodeNil(n Node) (isNil bool) { + switch impl := n.(type) { case *Branch: isNil = impl == nil case *Leaf: isNil = impl == nil default: - isNil = child == nil + isNil = n == nil } - if isNil { + return isNil +} + +func encodeChild(child Node, buffer io.Writer) (err error) { + if isNodeNil(child) { return nil } diff --git a/internal/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go index 60be03aa18..733e77cd13 100644 --- a/internal/trie/node/branch_encode_test.go +++ b/internal/trie/node/branch_encode_test.go @@ -123,7 +123,7 @@ func Test_Branch_Encode(t *testing.T) { wrappedErr: errTest, errMessage: "cannot write encoded value to buffer: test error", }, - "buffer write error for children encoded sequentially": { + "buffer write error for children encoding": { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, @@ -152,10 +152,10 @@ func Test_Branch_Encode(t *testing.T) { }, wrappedErr: errTest, errMessage: "cannot encode children of branch: " + - "cannot encode child at index 3: " + - "failed to write child to buffer: test error", + "cannot write encoding of child at index 3: " + + "test error", }, - "success with sequential children encoding": { + "success with children encoding": { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, @@ -218,7 +218,7 @@ func Test_Branch_Encode(t *testing.T) { } } -func Test_encodeChildrenInParallel(t *testing.T) { +func Test_encodeChildrenOpportunisticParallel(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -306,7 +306,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { previousCall = call } - err := encodeChildrenInParallel(testCase.children, buffer) + err := encodeChildrenOpportunisticParallel(testCase.children, buffer) if testCase.wrappedErr != nil { assert.ErrorIs(t, err, testCase.wrappedErr)