From a3ae30bd3a3dbf1f58ec8bbee58c1b4cd40994d6 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 16 Dec 2021 17:46:40 +0200 Subject: [PATCH] chore(lib/trie): refactor encoding and hash related code in trie package (#2077) * Chore(packages): move node interface, leaf & branch implementations, encoding and decoding functions in `internal/trie/node` * Chore(packages): create `internal/trie/recorder` subpackage (#2082) * Chore(packages): create `internal/trie/pools` subpackage * Chore(packages): create `internal/trie/codec` subpackage * Chore(tests): add tests with near full coverage * Chore(errors): improve error wrapping on trie node implementations and encoding/decoding * Optimization: use `sync.Pool` for header byte reading * Optimization: encode headers directly to buffer * Code addition: `GetValue() []byte` method for node interface * Code addition: `GetKey() []byte` method for node interface * Chore(comments): add and clarify existing comments * Chore(api): unexport node implementation fields: `Generation`, `Dirty`, `Encoding` and `Hash` * Minor change: trie `string()` method does not cache encoding in nodes. This is only used for debugging. 
Co-authored-by: Kishan Sagathiya --- internal/trie/codec/nibbles.go | 48 + internal/trie/codec/nibbles_test.go | 142 +++ internal/trie/node/branch.go | 50 + internal/trie/node/branch_encode.go | 231 ++++ internal/trie/node/branch_encode_test.go | 642 +++++++++++ internal/trie/node/branch_test.go | 95 ++ internal/trie/node/buffer.go | 16 + internal/trie/node/buffer_mock_test.go | 77 ++ internal/trie/node/children.go | 27 + internal/trie/node/children_test.go | 120 ++ internal/trie/node/copy.go | 77 ++ internal/trie/node/copy_test.go | 127 +++ internal/trie/node/decode.go | 145 +++ internal/trie/node/decode_test.go | 309 +++++ internal/trie/node/dirty.go | 24 + internal/trie/node/dirty_test.go | 150 +++ internal/trie/node/encode_decode_test.go | 89 ++ internal/trie/node/encode_doc.go | 28 + internal/trie/node/encode_test.go | 14 + internal/trie/node/generation.go | 24 + internal/trie/node/generation_test.go | 50 + internal/trie/node/hash.go | 135 +++ internal/trie/node/hash_test.go | 254 +++++ internal/trie/node/header.go | 67 ++ internal/trie/node/header_test.go | 245 ++++ internal/trie/node/key.go | 123 ++ internal/trie/node/key_test.go | 334 ++++++ internal/trie/node/leaf.go | 48 + internal/trie/node/leaf_encode.go | 117 ++ internal/trie/node/leaf_encode_test.go | 296 +++++ internal/trie/node/leaf_test.go | 77 ++ internal/trie/node/node.go | 22 + internal/trie/node/reader_mock_test.go | 49 + internal/trie/node/types.go | 17 + internal/trie/node/value.go | 18 + internal/trie/node/value_test.go | 30 + .../trie/node}/writer_mock_test.go | 4 +- internal/trie/pools/pools.go | 51 + internal/trie/record/node.go | 10 + internal/trie/record/recorder.go | 27 + internal/trie/record/recorder_test.go | 118 ++ lib/trie/bytesBuffer_mock_test.go | 77 -- lib/trie/codec.go | 67 -- lib/trie/codec_test.go | 80 -- lib/trie/database.go | 116 +- lib/trie/hash.go | 348 ------ lib/trie/hash_test.go | 1012 ----------------- lib/trie/lookup.go | 27 +- lib/trie/node.go | 535 +-------- 
lib/trie/node_mock_test.go | 183 --- lib/trie/node_test.go | 327 +----- lib/trie/print.go | 45 +- lib/trie/proof.go | 13 +- lib/trie/proof_test.go | 2 +- lib/trie/readwriter_mock_test.go | 64 -- lib/trie/recorder.go | 34 - lib/trie/trie.go | 388 ++++--- lib/trie/trie_test.go | 66 +- 58 files changed, 4863 insertions(+), 3048 deletions(-) create mode 100644 internal/trie/codec/nibbles.go create mode 100644 internal/trie/codec/nibbles_test.go create mode 100644 internal/trie/node/branch.go create mode 100644 internal/trie/node/branch_encode.go create mode 100644 internal/trie/node/branch_encode_test.go create mode 100644 internal/trie/node/branch_test.go create mode 100644 internal/trie/node/buffer.go create mode 100644 internal/trie/node/buffer_mock_test.go create mode 100644 internal/trie/node/children.go create mode 100644 internal/trie/node/children_test.go create mode 100644 internal/trie/node/copy.go create mode 100644 internal/trie/node/copy_test.go create mode 100644 internal/trie/node/decode.go create mode 100644 internal/trie/node/decode_test.go create mode 100644 internal/trie/node/dirty.go create mode 100644 internal/trie/node/dirty_test.go create mode 100644 internal/trie/node/encode_decode_test.go create mode 100644 internal/trie/node/encode_doc.go create mode 100644 internal/trie/node/encode_test.go create mode 100644 internal/trie/node/generation.go create mode 100644 internal/trie/node/generation_test.go create mode 100644 internal/trie/node/hash.go create mode 100644 internal/trie/node/hash_test.go create mode 100644 internal/trie/node/header.go create mode 100644 internal/trie/node/header_test.go create mode 100644 internal/trie/node/key.go create mode 100644 internal/trie/node/key_test.go create mode 100644 internal/trie/node/leaf.go create mode 100644 internal/trie/node/leaf_encode.go create mode 100644 internal/trie/node/leaf_encode_test.go create mode 100644 internal/trie/node/leaf_test.go create mode 100644 internal/trie/node/node.go create 
mode 100644 internal/trie/node/reader_mock_test.go create mode 100644 internal/trie/node/types.go create mode 100644 internal/trie/node/value.go create mode 100644 internal/trie/node/value_test.go rename {lib/trie => internal/trie/node}/writer_mock_test.go (95%) create mode 100644 internal/trie/pools/pools.go create mode 100644 internal/trie/record/node.go create mode 100644 internal/trie/record/recorder.go create mode 100644 internal/trie/record/recorder_test.go delete mode 100644 lib/trie/bytesBuffer_mock_test.go delete mode 100644 lib/trie/codec.go delete mode 100644 lib/trie/codec_test.go delete mode 100644 lib/trie/hash.go delete mode 100644 lib/trie/hash_test.go delete mode 100644 lib/trie/node_mock_test.go delete mode 100644 lib/trie/readwriter_mock_test.go delete mode 100644 lib/trie/recorder.go diff --git a/internal/trie/codec/nibbles.go b/internal/trie/codec/nibbles.go new file mode 100644 index 0000000000..7b6f9bd4de --- /dev/null +++ b/internal/trie/codec/nibbles.go @@ -0,0 +1,48 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +// NibblesToKeyLE converts a slice of nibbles with length k into a +// Little Endian byte slice. +// It assumes nibbles are already in Little Endian and does not rearrange nibbles. +// If the length of the input is odd, the result is +// [ 0000 in[0] | in[1] in[2] | ... | in[k-2] in[k-1] ] +// Otherwise, the result is +// [ in[0] in[1] | ... | in[k-2] in[k-1] ] +func NibblesToKeyLE(nibbles []byte) []byte { + if len(nibbles)%2 == 0 { + keyLE := make([]byte, len(nibbles)/2) + for i := 0; i < len(nibbles); i += 2 { + keyLE[i/2] = (nibbles[i] << 4 & 0xf0) | (nibbles[i+1] & 0xf) + } + return keyLE + } + + keyLE := make([]byte, len(nibbles)/2+1) + keyLE[0] = nibbles[0] + for i := 2; i < len(nibbles); i += 2 { + keyLE[i/2] = (nibbles[i-1] << 4 & 0xf0) | (nibbles[i] & 0xf) + } + + return keyLE +} + +// KeyLEToNibbles converts a Little Endian byte slice into nibbles. 
+// It assumes bytes are already in Little Endian and does not rearrange nibbles. +func KeyLEToNibbles(in []byte) (nibbles []byte) { + if len(in) == 0 { + return []byte{} + } else if len(in) == 1 && in[0] == 0 { + return []byte{0, 0} + } + + l := len(in) * 2 + nibbles = make([]byte, l) + for i, b := range in { + nibbles[2*i] = b / 16 + nibbles[2*i+1] = b % 16 + } + + return nibbles +} diff --git a/internal/trie/codec/nibbles_test.go b/internal/trie/codec/nibbles_test.go new file mode 100644 index 0000000000..fa2bbf4fdd --- /dev/null +++ b/internal/trie/codec/nibbles_test.go @@ -0,0 +1,142 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NibblesToKeyLE(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + nibbles []byte + keyLE []byte + }{ + "nil nibbles": { + keyLE: []byte{}, + }, + "empty nibbles": { + nibbles: []byte{}, + keyLE: []byte{}, + }, + "0xF 0xF": { + nibbles: []byte{0xF, 0xF}, + keyLE: []byte{0xFF}, + }, + "0x3 0xa 0x0 0x5": { + nibbles: []byte{0x3, 0xa, 0x0, 0x5}, + keyLE: []byte{0x3a, 0x05}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, + keyLE: []byte{0xaa, 0xff, 0x01}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc 0x2": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, + keyLE: []byte{0xaa, 0xff, 0x01, 0xc2}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, + keyLE: []byte{0xa, 0xaf, 0xf0, 0x1c}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + keyLE := NibblesToKeyLE(testCase.nibbles) + + assert.Equal(t, testCase.keyLE, keyLE) + }) + } +} + +func Test_KeyLEToNibbles(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in []byte + nibbles []byte + }{ + "nil input": { + nibbles: []byte{}, + }, + "empty input": { + in: 
[]byte{}, + nibbles: []byte{}, + }, + "0x0": { + in: []byte{0x0}, + nibbles: []byte{0, 0}}, + "0xFF": { + in: []byte{0xFF}, + nibbles: []byte{0xF, 0xF}}, + "0x3a 0x05": { + in: []byte{0x3a, 0x05}, + nibbles: []byte{0x3, 0xa, 0x0, 0x5}}, + "0xAA 0xFF 0x01": { + in: []byte{0xAA, 0xFF, 0x01}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}}, + "0xAA 0xFF 0x01 0xc2": { + in: []byte{0xAA, 0xFF, 0x01, 0xc2}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}}, + "0xAA 0xFF 0x01 0xc0": { + in: []byte{0xAA, 0xFF, 0x01, 0xc0}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x0}}, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nibbles := KeyLEToNibbles(testCase.in) + + assert.Equal(t, testCase.nibbles, nibbles) + }) + } +} + +func Test_NibblesKeyLE(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + nibblesToEncode []byte + nibblesDecoded []byte + }{ + "empty input": { + nibblesToEncode: []byte{}, + nibblesDecoded: []byte{}, + }, + "one byte": { + nibblesToEncode: []byte{1}, + nibblesDecoded: []byte{0, 1}, + }, + "two bytes": { + nibblesToEncode: []byte{1, 2}, + nibblesDecoded: []byte{1, 2}, + }, + "three bytes": { + nibblesToEncode: []byte{1, 2, 3}, + nibblesDecoded: []byte{0, 1, 2, 3}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + keyLE := NibblesToKeyLE(testCase.nibblesToEncode) + nibblesDecoded := KeyLEToNibbles(keyLE) + + assert.Equal(t, testCase.nibblesDecoded, nibblesDecoded) + }) + } +} diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go new file mode 100644 index 0000000000..7f3422a6f1 --- /dev/null +++ b/internal/trie/node/branch.go @@ -0,0 +1,50 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "fmt" + "sync" + + "github.com/ChainSafe/gossamer/lib/common" +) + +var _ Node = 
(*Branch)(nil) + +// Branch is a branch in the trie. +type Branch struct { + Key []byte // partial key + Children [16]Node + Value []byte + // dirty is true when the branch differs + // from the node stored in the database. + dirty bool + hashDigest []byte + encoding []byte + // generation is incremented on every trie Snapshot() call. + // Each node also contain a certain generation number, + // which is updated to match the trie generation once they are + // inserted, moved or iterated over. + generation uint64 + sync.RWMutex +} + +// NewBranch creates a new branch using the arguments given. +func NewBranch(key, value []byte, dirty bool, generation uint64) *Branch { + return &Branch{ + Key: key, + Value: value, + dirty: dirty, + generation: generation, + } +} + +func (b *Branch) String() string { + if len(b.Value) > 1024 { + return fmt.Sprintf("branch key=0x%x childrenBitmap=%b value (hashed)=0x%x dirty=%t", + b.Key, b.ChildrenBitmap(), common.MustBlake2bHash(b.Value), b.dirty) + } + return fmt.Sprintf("branch key=0x%x childrenBitmap=%b value=0x%x dirty=%t", + b.Key, b.ChildrenBitmap(), b.Value, b.dirty) +} diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go new file mode 100644 index 0000000000..badd3556f0 --- /dev/null +++ b/internal/trie/node/branch_encode.go @@ -0,0 +1,231 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + "fmt" + "hash" + "io" + + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/pools" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +// ScaleEncodeHash hashes the node (blake2b sum on encoded value) +// and then SCALE encodes it. This is used to encode children +// nodes of branches. 
+func (b *Branch) ScaleEncodeHash() (encoding []byte, err error) { + buffer := pools.DigestBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.DigestBuffers.Put(buffer) + + err = b.hash(buffer) + if err != nil { + return nil, fmt.Errorf("cannot hash branch: %w", err) + } + + encoding, err = scale.Marshal(buffer.Bytes()) + if err != nil { + return nil, fmt.Errorf("cannot scale encode hashed branch: %w", err) + } + + return encoding, nil +} + +func (b *Branch) hash(digestBuffer io.Writer) (err error) { + encodingBuffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + encodingBuffer.Reset() + defer pools.EncodingBuffers.Put(encodingBuffer) + + err = b.Encode(encodingBuffer) + if err != nil { + return fmt.Errorf("cannot encode leaf: %w", err) + } + + // if length of encoded branch is less than 32 bytes, do not hash + if encodingBuffer.Len() < 32 { + _, err = digestBuffer.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot write encoded branch to buffer: %w", err) + } + return nil + } + + // otherwise, hash encoded node + hasher := pools.Hashers.Get().(hash.Hash) + hasher.Reset() + defer pools.Hashers.Put(hasher) + + // Note: using the sync.Pool's buffer is useful here. + _, err = hasher.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot hash encoded node: %w", err) + } + + _, err = digestBuffer.Write(hasher.Sum(nil)) + if err != nil { + return fmt.Errorf("cannot write hash sum of branch to buffer: %w", err) + } + return nil +} + +// Encode encodes a branch with the encoding specified at the top of this package +// to the buffer given. 
+func (b *Branch) Encode(buffer Buffer) (err error) { + if !b.dirty && b.encoding != nil { + _, err = buffer.Write(b.encoding) + if err != nil { + return fmt.Errorf("cannot write stored encoding to buffer: %w", err) + } + return nil + } + + err = b.encodeHeader(buffer) + if err != nil { + return fmt.Errorf("cannot encode header: %w", err) + } + + keyLE := codec.NibblesToKeyLE(b.Key) + _, err = buffer.Write(keyLE) + if err != nil { + return fmt.Errorf("cannot write encoded key to buffer: %w", err) + } + + childrenBitmap := common.Uint16ToBytes(b.ChildrenBitmap()) + _, err = buffer.Write(childrenBitmap) + if err != nil { + return fmt.Errorf("cannot write children bitmap to buffer: %w", err) + } + + if b.Value != nil { + bytes, err := scale.Marshal(b.Value) + if err != nil { + return fmt.Errorf("cannot scale encode value: %w", err) + } + + _, err = buffer.Write(bytes) + if err != nil { + return fmt.Errorf("cannot write encoded value to buffer: %w", err) + } + } + + const parallel = false // TODO Done in pull request #2081 + if parallel { + err = encodeChildrenInParallel(b.Children, buffer) + } else { + err = encodeChildrenSequentially(b.Children, buffer) + } + if err != nil { + return fmt.Errorf("cannot encode children of branch: %w", err) + } + + return nil +} + +func encodeChildrenInParallel(children [16]Node, buffer io.Writer) (err error) { + type result struct { + index int + buffer *bytes.Buffer + err error + } + + resultsCh := make(chan result) + + for i, child := range children { + go func(index int, child Node) { + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + // buffer is put back in the pool after processing its + // data in the select block below. 
+ + err := encodeChild(child, buffer) + + resultsCh <- result{ + index: index, + buffer: buffer, + err: err, + } + }(i, child) + } + + currentIndex := 0 + resultBuffers := make([]*bytes.Buffer, len(children)) + for range children { + result := <-resultsCh + if result.err != nil && err == nil { // only set the first error we get + err = result.err + } + + resultBuffers[result.index] = result.buffer + + // write as many completed buffers to the result buffer. + for currentIndex < len(children) && + resultBuffers[currentIndex] != nil { + bufferSlice := resultBuffers[currentIndex].Bytes() + if len(bufferSlice) > 0 { + // note buffer.Write copies the byte slice given as argument + _, writeErr := buffer.Write(bufferSlice) + if writeErr != nil && err == nil { + err = fmt.Errorf( + "cannot write encoding of child at index %d: %w", + currentIndex, writeErr) + } + } + + pools.EncodingBuffers.Put(resultBuffers[currentIndex]) + resultBuffers[currentIndex] = nil + + currentIndex++ + } + } + + for _, buffer := range resultBuffers { + if buffer == nil { // already emptied and put back in pool + continue + } + pools.EncodingBuffers.Put(buffer) + } + + return err +} + +func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error) { + for i, child := range children { + err = encodeChild(child, buffer) + if err != nil { + return fmt.Errorf("cannot encode child at index %d: %w", i, err) + } + } + return nil +} + +func encodeChild(child Node, buffer io.Writer) (err error) { + var isNil bool + switch impl := child.(type) { + case *Branch: + isNil = impl == nil + case *Leaf: + isNil = impl == nil + default: + isNil = child == nil + } + if isNil { + return nil + } + + scaleEncodedChild, err := child.ScaleEncodeHash() + if err != nil { + return fmt.Errorf("failed to hash and scale encode child: %w", err) + } + + _, err = buffer.Write(scaleEncodedChild) + if err != nil { + return fmt.Errorf("failed to write child to buffer: %w", err) + } + + return nil +} diff --git 
a/internal/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go new file mode 100644 index 0000000000..9c1fc50703 --- /dev/null +++ b/internal/trie/node/branch_encode_test.go @@ -0,0 +1,642 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Branch_ScaleEncodeHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + encoding []byte + wrappedErr error + errMessage string + }{ + "empty branch": { + branch: &Branch{}, + encoding: []byte{0xc, 0x80, 0x0, 0x0}, + }, + "non empty branch": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + Children: [16]Node{ + nil, nil, &Leaf{Key: []byte{9}}, + }, + }, + encoding: []byte{0x2c, 0xc2, 0x12, 0x4, 0x0, 0x8, 0x3, 0x4, 0xc, 0x41, 0x9, 0x0}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, err := testCase.branch.ScaleEncodeHash() + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + assert.Equal(t, testCase.encoding, encoding) + }) + } +} + +func Test_Branch_hash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + write writeCall + errWrapped error + errMessage string + }{ + "empty branch": { + branch: &Branch{}, + write: writeCall{ + written: []byte{128, 0, 0}, + }, + }, + "less than 32 bytes encoding": { + branch: &Branch{ + Key: []byte{1, 2}, + }, + write: writeCall{ + written: []byte{130, 18, 0, 0}, + }, + }, + "less than 32 bytes encoding write error": { + branch: &Branch{ + Key: []byte{1, 2}, + }, + write: writeCall{ + written: []byte{130, 18, 0, 0}, + err: errTest, + }, + errWrapped: errTest, + errMessage: 
"cannot write encoded branch to buffer: test error", + }, + "more than 32 bytes encoding": { + branch: &Branch{ + Key: repeatBytes(100, 1), + }, + write: writeCall{ + written: []byte{ + 70, 102, 188, 24, 31, 68, 86, 114, + 95, 156, 225, 138, 175, 254, 176, 251, + 81, 84, 193, 40, 11, 234, 142, 233, + 69, 250, 158, 86, 72, 228, 66, 46}, + }, + }, + "more than 32 bytes encoding write error": { + branch: &Branch{ + Key: repeatBytes(100, 1), + }, + write: writeCall{ + written: []byte{ + 70, 102, 188, 24, 31, 68, 86, 114, + 95, 156, 225, 138, 175, 254, 176, 251, + 81, 84, 193, 40, 11, 234, 142, 233, + 69, 250, 158, 86, 72, 228, 66, 46}, + err: errTest, + }, + errWrapped: errTest, + errMessage: "cannot write hash sum of branch to buffer: test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + digestBuffer := NewMockWriter(ctrl) + digestBuffer.EXPECT().Write(testCase.write.written). 
+ Return(testCase.write.n, testCase.write.err) + + err := testCase.branch.hash(digestBuffer) + + if testCase.errWrapped != nil { + assert.ErrorIs(t, err, testCase.errWrapped) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_Branch_Encode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + writes []writeCall + wrappedErr error + errMessage string + }{ + "clean branch with encoding": { + branch: &Branch{ + encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { // stored encoding + written: []byte{1, 2, 3}, + }, + }, + }, + "write error for clean branch with encoding": { + branch: &Branch{ + encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { // stored encoding + written: []byte{1, 2, 3}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write stored encoding to buffer: test error", + }, + "header encoding error": { + branch: &Branch{ + Key: make([]byte, 63+(1<<16)), + }, + writes: []writeCall{ + { // header + written: []byte{191}, + }, + }, + wrappedErr: ErrPartialKeyTooBig, + errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", + }, + "buffer write error for encoded key": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded key to buffer: test error", + }, + "buffer write error for children bitmap": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]Node{ + nil, nil, nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, 
+ errMessage: "cannot write children bitmap to buffer: test error", + }, + "buffer write error for value": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]Node{ + nil, nil, nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded value to buffer: test error", + }, + "buffer write error for children encoded sequentially": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]Node{ + nil, nil, nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + }, + { // children + written: []byte{12, 65, 9, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot encode children of branch: " + + "cannot encode child at index 3: " + + "failed to write child to buffer: test error", + }, + "success with sequential children encoding": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]Node{ + nil, nil, nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + }, + { // first children + written: []byte{12, 65, 9, 0}, + }, + { // second children + written: []byte{12, 65, 11, 0}, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + 
ctrl := gomock.NewController(t) + + buffer := NewMockBuffer(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := buffer.EXPECT(). + Write(write.written). + Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := testCase.branch.Encode(buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_encodeChildrenInParallel(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + children [16]Node + writes []writeCall + wrappedErr error + errMessage string + }{ + "no children": {}, + "first child not nil": { + children: [16]Node{ + &Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "last child not nil": { + children: [16]Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "first two children not nil": { + children: [16]Node{ + &Leaf{Key: []byte{1}}, + &Leaf{Key: []byte{2}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + { + written: []byte{12, 65, 2, 0}, + }, + }, + }, + "encoding error": { + children: [16]Node{ + nil, nil, nil, nil, + nil, nil, nil, nil, + nil, nil, nil, + &Leaf{ + Key: []byte{1}, + }, + nil, nil, nil, nil, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoding of child at index 11: " + + "test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := 
buffer.EXPECT(). + Write(write.written). + Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := encodeChildrenInParallel(testCase.children, buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_encodeChildrenSequentially(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + children [16]Node + writes []writeCall + wrappedErr error + errMessage string + }{ + "no children": {}, + "first child not nil": { + children: [16]Node{ + &Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "last child not nil": { + children: [16]Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "first two children not nil": { + children: [16]Node{ + &Leaf{Key: []byte{1}}, + &Leaf{Key: []byte{2}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + { + written: []byte{12, 65, 2, 0}, + }, + }, + }, + "encoding error": { + children: [16]Node{ + nil, nil, nil, nil, + nil, nil, nil, nil, + nil, nil, nil, + &Leaf{ + Key: []byte{1}, + }, + nil, nil, nil, nil, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot encode child at index 11: " + + "failed to write child to buffer: test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := buffer.EXPECT(). + Write(write.written). 
+ Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := encodeChildrenSequentially(testCase.children, buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_encodeChild(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + child Node + writeCall bool + write writeCall + wrappedErr error + errMessage string + }{ + "nil node": {}, + "nil leaf": { + child: (*Leaf)(nil), + }, + "nil branch": { + child: (*Branch)(nil), + }, + "empty leaf child": { + child: &Leaf{}, + writeCall: true, + write: writeCall{ + written: []byte{8, 64, 0}, + }, + }, + "empty branch child": { + child: &Branch{}, + writeCall: true, + write: writeCall{ + written: []byte{12, 128, 0, 0}, + }, + }, + "buffer write error": { + child: &Branch{}, + writeCall: true, + write: writeCall{ + written: []byte{12, 128, 0, 0}, + err: errTest, + }, + wrappedErr: errTest, + errMessage: "failed to write child to buffer: test error", + }, + "leaf child": { + child: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + }, + writeCall: true, + write: writeCall{ + written: []byte{16, 65, 1, 4, 2}, + }, + }, + "branch child": { + child: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + Children: [16]Node{ + nil, nil, &Leaf{ + Key: []byte{5}, + Value: []byte{6}, + }, + }, + }, + writeCall: true, + write: writeCall{ + written: []byte{44, 193, 1, 4, 0, 4, 2, 16, 65, 5, 4, 6}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockWriter(ctrl) + + if testCase.writeCall { + buffer.EXPECT(). + Write(testCase.write.written). 
+ Return(testCase.write.n, testCase.write.err) + } + + err := encodeChild(testCase.child, buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/trie/node/branch_test.go b/internal/trie/node/branch_test.go new file mode 100644 index 0000000000..a7d4591c32 --- /dev/null +++ b/internal/trie/node/branch_test.go @@ -0,0 +1,95 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewBranch(t *testing.T) { + t.Parallel() + + key := []byte{1, 2} + value := []byte{3, 4} + const dirty = true + const generation = 9 + + branch := NewBranch(key, value, dirty, generation) + + expectedBranch := &Branch{ + Key: key, + Value: value, + dirty: dirty, + generation: generation, + } + assert.Equal(t, expectedBranch, branch) + + // Check modifying passed slice modifies branch slices + key[0] = 11 + value[0] = 13 + assert.Equal(t, expectedBranch, branch) +} + +func Test_Branch_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + s string + }{ + "empty branch": { + branch: &Branch{}, + s: "branch key=0x childrenBitmap=0 value=0x dirty=false", + }, + "branch with value smaller than 1024": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + Children: [16]Node{ + nil, nil, nil, + &Leaf{}, + nil, nil, nil, + &Branch{}, + nil, nil, nil, + &Leaf{}, + nil, nil, nil, nil, + }, + }, + s: "branch key=0x0102 childrenBitmap=100010001000 value=0x0304 dirty=true", + }, + "branch with value higher than 1024": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: make([]byte, 1025), + dirty: true, + Children: [16]Node{ + nil, nil, nil, + &Leaf{}, + nil, nil, nil, + &Branch{}, + nil, nil, nil, + &Leaf{}, + nil, nil, nil, nil, + }, + }, + s: "branch 
key=0x0102 childrenBitmap=100010001000 " + + "value (hashed)=0x307861663233363133353361303538646238383034626337353735323831663131663735313265326331346336373032393864306232336630396538386266333066 " + //nolint:lll + "dirty=true", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := testCase.branch.String() + + assert.Equal(t, testCase.s, s) + }) + } +} diff --git a/internal/trie/node/buffer.go b/internal/trie/node/buffer.go new file mode 100644 index 0000000000..c4a2e74cf1 --- /dev/null +++ b/internal/trie/node/buffer.go @@ -0,0 +1,16 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import "io" + +//go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE . Buffer +//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer + +// Buffer is an interface with some methods of *bytes.Buffer. +type Buffer interface { + io.Writer + Len() int + Bytes() []byte +} diff --git a/internal/trie/node/buffer_mock_test.go b/internal/trie/node/buffer_mock_test.go new file mode 100644 index 0000000000..8977a1ed52 --- /dev/null +++ b/internal/trie/node/buffer_mock_test.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/internal/trie/node (interfaces: Buffer) + +// Package node is a generated GoMock package. +package node + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockBuffer is a mock of Buffer interface. +type MockBuffer struct { + ctrl *gomock.Controller + recorder *MockBufferMockRecorder +} + +// MockBufferMockRecorder is the mock recorder for MockBuffer. +type MockBufferMockRecorder struct { + mock *MockBuffer +} + +// NewMockBuffer creates a new mock instance. 
+func NewMockBuffer(ctrl *gomock.Controller) *MockBuffer { + mock := &MockBuffer{ctrl: ctrl} + mock.recorder = &MockBufferMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBuffer) EXPECT() *MockBufferMockRecorder { + return m.recorder +} + +// Bytes mocks base method. +func (m *MockBuffer) Bytes() []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Bytes") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Bytes indicates an expected call of Bytes. +func (mr *MockBufferMockRecorder) Bytes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockBuffer)(nil).Bytes)) +} + +// Len mocks base method. +func (m *MockBuffer) Len() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Len") + ret0, _ := ret[0].(int) + return ret0 +} + +// Len indicates an expected call of Len. +func (mr *MockBufferMockRecorder) Len() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockBuffer)(nil).Len)) +} + +// Write mocks base method. +func (m *MockBuffer) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockBufferMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockBuffer)(nil).Write), arg0) +} diff --git a/internal/trie/node/children.go b/internal/trie/node/children.go new file mode 100644 index 0000000000..be4f9e47ea --- /dev/null +++ b/internal/trie/node/children.go @@ -0,0 +1,27 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// ChildrenBitmap returns the 16 bit bitmap +// of the children in the branch. 
+func (b *Branch) ChildrenBitmap() (bitmap uint16) { + for i := uint(0); i < 16; i++ { + if b.Children[i] == nil { + continue + } + bitmap |= 1 << i + } + return bitmap +} + +// NumChildren returns the total number of children +// in the branch. +func (b *Branch) NumChildren() (count int) { + for i := 0; i < 16; i++ { + if b.Children[i] != nil { + count++ + } + } + return count +} diff --git a/internal/trie/node/children_test.go b/internal/trie/node/children_test.go new file mode 100644 index 0000000000..4b60039656 --- /dev/null +++ b/internal/trie/node/children_test.go @@ -0,0 +1,120 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_ChildrenBitmap(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + bitmap uint16 + }{ + "no children": { + branch: &Branch{}, + }, + "index 0": { + branch: &Branch{ + Children: [16]Node{ + &Leaf{}, + }, + }, + bitmap: 1, + }, + "index 0 and 4": { + branch: &Branch{ + Children: [16]Node{ + &Leaf{}, + nil, nil, nil, + &Leaf{}, + }, + }, + bitmap: 1<<4 + 1, + }, + "index 0, 4 and 15": { + branch: &Branch{ + Children: [16]Node{ + &Leaf{}, + nil, nil, nil, + &Leaf{}, + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &Leaf{}, + }, + }, + bitmap: 1<<15 + 1<<4 + 1, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + bitmap := testCase.branch.ChildrenBitmap() + + assert.Equal(t, testCase.bitmap, bitmap) + }) + } +} + +func Test_Branch_NumChildren(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + count int + }{ + "zero": { + branch: &Branch{}, + }, + "one": { + branch: &Branch{ + Children: [16]Node{ + &Leaf{}, + }, + }, + count: 1, + }, + "two": { + branch: &Branch{ + Children: [16]Node{ + &Leaf{}, + nil, nil, nil, + &Leaf{}, + }, + }, + count: 2, + }, 
+ "three": { + branch: &Branch{ + Children: [16]Node{ + &Leaf{}, + nil, nil, nil, + &Leaf{}, + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &Leaf{}, + }, + }, + count: 3, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + count := testCase.branch.NumChildren() + + assert.Equal(t, testCase.count, count) + }) + } +} diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go new file mode 100644 index 0000000000..1a59f24d21 --- /dev/null +++ b/internal/trie/node/copy.go @@ -0,0 +1,77 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// Copy deep copies the branch. +func (b *Branch) Copy() Node { + b.RLock() + defer b.RUnlock() + + cpy := &Branch{ + Children: b.Children, // copy interface pointers + dirty: b.dirty, + generation: b.generation, + } + copy(cpy.Key, b.Key) + + if b.Key != nil { + cpy.Key = make([]byte, len(b.Key)) + copy(cpy.Key, b.Key) + } + + // nil and []byte{} are encoded differently, watch out! + if b.Value != nil { + cpy.Value = make([]byte, len(b.Value)) + copy(cpy.Value, b.Value) + } + + if b.hashDigest != nil { + cpy.hashDigest = make([]byte, len(b.hashDigest)) + copy(cpy.hashDigest, b.hashDigest) + } + + if b.encoding != nil { + cpy.encoding = make([]byte, len(b.encoding)) + copy(cpy.encoding, b.encoding) + } + + return cpy +} + +// Copy deep copies the leaf. +func (l *Leaf) Copy() Node { + l.RLock() + defer l.RUnlock() + + l.encodingMu.RLock() + defer l.encodingMu.RUnlock() + + cpy := &Leaf{ + dirty: l.dirty, + generation: l.generation, + } + + if l.Key != nil { + cpy.Key = make([]byte, len(l.Key)) + copy(cpy.Key, l.Key) + } + + // nil and []byte{} are encoded differently, watch out! 
+ if l.Value != nil { + cpy.Value = make([]byte, len(l.Value)) + copy(cpy.Value, l.Value) + } + + if l.hashDigest != nil { + cpy.hashDigest = make([]byte, len(l.hashDigest)) + copy(cpy.hashDigest, l.hashDigest) + } + + if l.encoding != nil { + cpy.encoding = make([]byte, len(l.encoding)) + copy(cpy.encoding, l.encoding) + } + + return cpy +} diff --git a/internal/trie/node/copy_test.go b/internal/trie/node/copy_test.go new file mode 100644 index 0000000000..bff0f409c2 --- /dev/null +++ b/internal/trie/node/copy_test.go @@ -0,0 +1,127 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testForSliceModif(t *testing.T, original, copied []byte) { + t.Helper() + require.Equal(t, len(original), len(copied)) + if len(copied) == 0 { + // cannot test for modification + return + } + original[0]++ + assert.NotEqual(t, copied, original) +} + +func Test_Branch_Copy(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + expectedBranch *Branch + }{ + "empty branch": { + branch: &Branch{}, + expectedBranch: &Branch{}, + }, + "non empty branch": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + Children: [16]Node{ + nil, nil, &Leaf{Key: []byte{9}}, + }, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + expectedBranch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + Children: [16]Node{ + nil, nil, &Leaf{Key: []byte{9}}, + }, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodeCopy := testCase.branch.Copy() + + branchCopy, ok := nodeCopy.(*Branch) + require.True(t, ok) + + assert.Equal(t, testCase.expectedBranch, branchCopy) + testForSliceModif(t, testCase.branch.Key, branchCopy.Key) + 
testForSliceModif(t, testCase.branch.Value, branchCopy.Value) + testForSliceModif(t, testCase.branch.hashDigest, branchCopy.hashDigest) + testForSliceModif(t, testCase.branch.encoding, branchCopy.encoding) + + testCase.branch.Children[15] = &Leaf{Key: []byte("modified")} + assert.NotEqual(t, branchCopy.Children, testCase.branch.Children) + }) + } +} + +func Test_Leaf_Copy(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + expectedLeaf *Leaf + }{ + "empty leaf": { + leaf: &Leaf{}, + expectedLeaf: &Leaf{}, + }, + "non empty leaf": { + leaf: &Leaf{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + expectedLeaf: &Leaf{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodeCopy := testCase.leaf.Copy() + + leafCopy, ok := nodeCopy.(*Leaf) + require.True(t, ok) + + assert.Equal(t, testCase.expectedLeaf, leafCopy) + testForSliceModif(t, testCase.leaf.Key, leafCopy.Key) + testForSliceModif(t, testCase.leaf.Value, leafCopy.Value) + testForSliceModif(t, testCase.leaf.hashDigest, leafCopy.hashDigest) + testForSliceModif(t, testCase.leaf.encoding, leafCopy.encoding) + }) + } +} diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go new file mode 100644 index 0000000000..007cae95c3 --- /dev/null +++ b/internal/trie/node/decode.go @@ -0,0 +1,145 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/internal/trie/pools" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +var ( + ErrReadHeaderByte = errors.New("cannot read header byte") + ErrUnknownNodeType = errors.New("unknown node type") + ErrNodeTypeIsNotABranch = errors.New("node type is not a 
branch") + ErrNodeTypeIsNotALeaf = errors.New("node type is not a leaf") + ErrDecodeValue = errors.New("cannot decode value") + ErrReadChildrenBitmap = errors.New("cannot read children bitmap") + ErrDecodeChildHash = errors.New("cannot decode child hash") +) + +// Decode decodes a node from a reader. +// For branch decoding, see the comments on decodeBranch. +// For leaf decoding, see the comments on decodeLeaf. +func Decode(reader io.Reader) (n Node, err error) { + buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) + defer pools.SingleByteBuffers.Put(buffer) + oneByteBuf := buffer.Bytes() + _, err = reader.Read(oneByteBuf) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) + } + header := oneByteBuf[0] + + nodeType := Type(header >> 6) + switch nodeType { + case LeafType: + n, err = decodeLeaf(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode leaf: %w", err) + } + return n, nil + case BranchType, BranchWithValueType: + n, err = decodeBranch(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode branch: %w", err) + } + return n, nil + default: + return nil, fmt.Errorf("%w: %d", ErrUnknownNodeType, nodeType) + } +} + +// decodeBranch reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. +// Note that since the encoded branch stores the hash of the children nodes, we are not +// reconstructing the child nodes from the encoding. This function instead stubs where the +// children are known to be with an empty leaf. The children nodes hashes are then used to +// find other values using the persistent database. 
+func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { + nodeType := Type(header >> 6) + if nodeType != BranchType && nodeType != BranchWithValueType { + return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) + } + + branch = new(Branch) + + keyLen := header & keyLenOffset + branch.Key, err = decodeKey(reader, keyLen) + if err != nil { + return nil, fmt.Errorf("cannot decode key: %w", err) + } + + childrenBitmap := make([]byte, 2) + _, err = reader.Read(childrenBitmap) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadChildrenBitmap, err) + } + + sd := scale.NewDecoder(reader) + + if nodeType == BranchWithValueType { + var value []byte + // branch w/ value + err := sd.Decode(&value) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) + } + branch.Value = value + } + + for i := 0; i < 16; i++ { + if (childrenBitmap[i/8]>>(i%8))&1 != 1 { + continue + } + var hash []byte + err := sd.Decode(&hash) + if err != nil { + return nil, fmt.Errorf("%w: at index %d: %s", + ErrDecodeChildHash, i, err) + } + + branch.Children[i] = &Leaf{ + hashDigest: hash, + } + } + + branch.dirty = true + + return branch, nil +} + +// decodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. 
+func decodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { + nodeType := Type(header >> 6) + if nodeType != LeafType { + return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) + } + + leaf = &Leaf{ + dirty: true, + } + + keyLen := header & keyLenOffset + leaf.Key, err = decodeKey(reader, keyLen) + if err != nil { + return nil, fmt.Errorf("cannot decode key: %w", err) + } + + sd := scale.NewDecoder(reader) + var value []byte + err = sd.Decode(&value) + if err != nil && !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) + } + + if len(value) > 0 { + leaf.Value = value + } + + return leaf, nil +} diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go new file mode 100644 index 0000000000..9998dde230 --- /dev/null +++ b/internal/trie/node/decode_test.go @@ -0,0 +1,309 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + "io" + "testing" + + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { + encoded, err := scale.Marshal(b) + require.NoError(t, err) + return encoded +} + +func concatByteSlices(slices [][]byte) (concatenated []byte) { + length := 0 + for i := range slices { + length += len(slices[i]) + } + concatenated = make([]byte, 0, length) + for _, slice := range slices { + concatenated = append(concatenated, slice...) 
+ } + return concatenated +} + +func Test_Decode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + n Node + errWrapped error + errMessage string + }{ + "no data": { + reader: bytes.NewReader(nil), + errWrapped: ErrReadHeaderByte, + errMessage: "cannot read header byte: EOF", + }, + "unknown node type": { + reader: bytes.NewReader([]byte{0}), + errWrapped: ErrUnknownNodeType, + errMessage: "unknown node type: 0", + }, + "leaf decoding error": { + reader: bytes.NewReader([]byte{ + 65, // node type 1 (leaf) and key length 1 + // missing key data byte + }), + errWrapped: ErrReadKeyData, + errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", + }, + "leaf success": { + reader: bytes.NewReader( + append( + []byte{ + 65, // node type 1 (leaf) and key length 1 + 9, // key data + }, + scaleEncodeBytes(t, 1, 2, 3)..., + ), + ), + n: &Leaf{ + Key: []byte{9}, + Value: []byte{1, 2, 3}, + dirty: true, + }, + }, + "branch decoding error": { + reader: bytes.NewReader([]byte{ + 129, // node type 2 (branch without value) and key length 1 + // missing key data byte + }), + errWrapped: ErrReadKeyData, + errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", + }, + "branch success": { + reader: bytes.NewReader( + []byte{ + 129, // node type 2 (branch without value) and key length 1 + 9, // key data + 0, 0, // no children bitmap + }, + ), + n: &Branch{ + Key: []byte{9}, + dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + n, err := Decode(testCase.reader) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.n, n) + }) + } +} + +func Test_decodeBranch(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + header byte + branch *Branch + errWrapped error + errMessage string + }{ 
+ "no data with header 1": { + reader: bytes.NewBuffer(nil), + header: 65, + errWrapped: ErrNodeTypeIsNotABranch, + errMessage: "node type is not a branch: 1", + }, + "key decoding error": { + reader: bytes.NewBuffer([]byte{ + // missing key data byte + }), + header: 129, // node type 2 (branch without value) and key length 1 + errWrapped: ErrReadKeyData, + errMessage: "cannot decode key: cannot read key data: EOF", + }, + "children bitmap read error": { + reader: bytes.NewBuffer([]byte{ + 9, // key data + // missing children bitmap 2 bytes + }), + header: 129, // node type 2 (branch without value) and key length 1 + errWrapped: ErrReadChildrenBitmap, + errMessage: "cannot read children bitmap: EOF", + }, + "children decoding error": { + reader: bytes.NewBuffer([]byte{ + 9, // key data + 0, 4, // children bitmap + // missing children scale encoded data + }), + header: 129, // node type 2 (branch without value) and key length 1 + errWrapped: ErrDecodeChildHash, + errMessage: "cannot decode child hash: at index 10: EOF", + }, + "success node type 2": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + { + 9, // key data + 0, 4, // children bitmap + }, + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash + }), + ), + header: 129, // node type 2 (branch without value) and key length 1 + branch: &Branch{ + Key: []byte{9}, + Children: [16]Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &Leaf{ + hashDigest: []byte{1, 2, 3, 4, 5}, + }, + }, + dirty: true, + }, + }, + "value decoding error for node type 3": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + {9}, // key data + {0, 4}, // children bitmap + // missing encoded branch value + }), + ), + header: 193, // node type 3 (branch with value) and key length 1 + errWrapped: ErrDecodeValue, + errMessage: "cannot decode value: EOF", + }, + "success node type 3": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + {9}, // key data + {0, 4}, // children bitmap + scaleEncodeBytes(t, 7, 8, 
9), // branch value + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash + }), + ), + header: 193, // node type 3 (branch with value) and key length 1 + branch: &Branch{ + Key: []byte{9}, + Value: []byte{7, 8, 9}, + Children: [16]Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &Leaf{ + hashDigest: []byte{1, 2, 3, 4, 5}, + }, + }, + dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + branch, err := decodeBranch(testCase.reader, testCase.header) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.branch, branch) + }) + } +} + +func Test_decodeLeaf(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + header byte + leaf *Leaf + errWrapped error + errMessage string + }{ + "no data with header 1": { + reader: bytes.NewBuffer(nil), + header: 1, + errWrapped: ErrNodeTypeIsNotALeaf, + errMessage: "node type is not a leaf: 0", + }, + "key decoding error": { + reader: bytes.NewBuffer([]byte{ + // missing key data byte + }), + header: 65, // node type 1 (leaf) and key length 1 + errWrapped: ErrReadKeyData, + errMessage: "cannot decode key: cannot read key data: EOF", + }, + "value decoding error": { + reader: bytes.NewBuffer([]byte{ + 9, // key data + 255, 255, // bad value data + }), + header: 65, // node type 1 (leaf) and key length 1 + errWrapped: ErrDecodeValue, + errMessage: "cannot decode value: could not decode invalid integer", + }, + "zero value": { + reader: bytes.NewBuffer([]byte{ + 9, // key data + // missing value data + }), + header: 65, // node type 1 (leaf) and key length 1 + leaf: &Leaf{ + Key: []byte{9}, + dirty: true, + }, + }, + "success": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + {9}, // key data + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // value data + }), + ), + header: 65, // node type 1 (leaf) and key 
length 1 + leaf: &Leaf{ + Key: []byte{9}, + Value: []byte{1, 2, 3, 4, 5}, + dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + leaf, err := decodeLeaf(testCase.reader, testCase.header) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.leaf, leaf) + }) + } +} diff --git a/internal/trie/node/dirty.go b/internal/trie/node/dirty.go new file mode 100644 index 0000000000..27d0367014 --- /dev/null +++ b/internal/trie/node/dirty.go @@ -0,0 +1,24 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// IsDirty returns the dirty status of the branch. +func (b *Branch) IsDirty() bool { + return b.dirty +} + +// SetDirty sets the dirty status to the branch. +func (b *Branch) SetDirty(dirty bool) { + b.dirty = dirty +} + +// IsDirty returns the dirty status of the leaf. +func (l *Leaf) IsDirty() bool { + return l.dirty +} + +// SetDirty sets the dirty status to the leaf. 
+func (l *Leaf) SetDirty(dirty bool) { + l.dirty = dirty +} diff --git a/internal/trie/node/dirty_test.go b/internal/trie/node/dirty_test.go new file mode 100644 index 0000000000..ebe9c02fa1 --- /dev/null +++ b/internal/trie/node/dirty_test.go @@ -0,0 +1,150 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_IsDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + dirty bool + }{ + "not dirty": { + branch: &Branch{}, + }, + "dirty": { + branch: &Branch{ + dirty: true, + }, + dirty: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + dirty := testCase.branch.IsDirty() + + assert.Equal(t, testCase.dirty, dirty) + }) + } +} + +func Test_Branch_SetDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + dirty bool + expected *Branch + }{ + "not dirty to not dirty": { + branch: &Branch{}, + expected: &Branch{}, + }, + "not dirty to dirty": { + branch: &Branch{}, + dirty: true, + expected: &Branch{dirty: true}, + }, + "dirty to not dirty": { + branch: &Branch{dirty: true}, + expected: &Branch{}, + }, + "dirty to dirty": { + branch: &Branch{dirty: true}, + dirty: true, + expected: &Branch{dirty: true}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.branch.SetDirty(testCase.dirty) + + assert.Equal(t, testCase.expected, testCase.branch) + }) + } +} + +func Test_Leaf_IsDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + dirty bool + }{ + "not dirty": { + leaf: &Leaf{}, + }, + "dirty": { + leaf: &Leaf{ + dirty: true, + }, + dirty: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + 
t.Parallel() + + dirty := testCase.leaf.IsDirty() + + assert.Equal(t, testCase.dirty, dirty) + }) + } +} + +func Test_Leaf_SetDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + dirty bool + expected *Leaf + }{ + "not dirty to not dirty": { + leaf: &Leaf{}, + expected: &Leaf{}, + }, + "not dirty to dirty": { + leaf: &Leaf{}, + dirty: true, + expected: &Leaf{dirty: true}, + }, + "dirty to not dirty": { + leaf: &Leaf{dirty: true}, + expected: &Leaf{}, + }, + "dirty to dirty": { + leaf: &Leaf{dirty: true}, + dirty: true, + expected: &Leaf{dirty: true}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.leaf.SetDirty(testCase.dirty) + + assert.Equal(t, testCase.expected, testCase.leaf) + }) + } +} diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go new file mode 100644 index 0000000000..898ed9b7e0 --- /dev/null +++ b/internal/trie/node/encode_decode_test.go @@ -0,0 +1,89 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Branch_Encode_Decode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branchToEncode *Branch + branchDecoded *Branch + }{ + "empty branch": { + branchToEncode: new(Branch), + branchDecoded: &Branch{ + Key: []byte{}, + dirty: true, + }, + }, + "branch with key 5": { + branchToEncode: &Branch{ + Key: []byte{5}, + }, + branchDecoded: &Branch{ + Key: []byte{5}, + dirty: true, + }, + }, + "branch with two bytes key": { + branchToEncode: &Branch{ + Key: []byte{0xf, 0xa}, // note: each byte cannot be larger than 0xf + }, + branchDecoded: &Branch{ + Key: []byte{0xf, 0xa}, + dirty: true, + }, + }, + "branch with child": { + branchToEncode: &Branch{ + Key: []byte{5}, + Children: [16]Node{ + 
&Leaf{ + Key: []byte{9}, + Value: []byte{10}, + }, + }, + }, + branchDecoded: &Branch{ + Key: []byte{5}, + Children: [16]Node{ + &Leaf{ + hashDigest: []byte{0x41, 0x9, 0x4, 0xa}, + }, + }, + dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + buffer := bytes.NewBuffer(nil) + + err := testCase.branchToEncode.Encode(buffer) + require.NoError(t, err) + + oneBuffer := make([]byte, 1) + _, err = buffer.Read(oneBuffer) + require.NoError(t, err) + header := oneBuffer[0] + + resultBranch, err := decodeBranch(buffer, header) + require.NoError(t, err) + + assert.Equal(t, testCase.branchDecoded, resultBranch) + }) + } +} diff --git a/internal/trie/node/encode_doc.go b/internal/trie/node/encode_doc.go new file mode 100644 index 0000000000..1a8b6a1c0a --- /dev/null +++ b/internal/trie/node/encode_doc.go @@ -0,0 +1,28 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +//nolint:lll +// Modified Merkle-Patricia Trie +// See https://github.com/w3f/polkadot-spec/blob/master/runtime-environment-spec/polkadot_re_spec.pdf for the full specification. +// +// Note that for the following definitions, `|` denotes concatenation +// +// Branch encoding: +// NodeHeader | Extra partial key length | Partial Key | Value +// `NodeHeader` is a byte such that: +// most significant two bits of `NodeHeader`: 10 if branch w/o value, 11 if branch w/ value +// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) +// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length +// `Partial Key` is the branch's key +// `Value` is: Children Bitmap | SCALE Branch node Value | Hash(Enc(Child[i_1])) | Hash(Enc(Child[i_2])) | ... 
| Hash(Enc(Child[i_n])) +// +// Leaf encoding: +// NodeHeader | Extra partial key length | Partial Key | Value +// `NodeHeader` is a byte such that: +// most significant two bits of `NodeHeader`: 01 +// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) +// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length +// `Partial Key` is the leaf's key +// `Value` is the leaf's SCALE encoded value diff --git a/internal/trie/node/encode_test.go b/internal/trie/node/encode_test.go new file mode 100644 index 0000000000..cc72efc06a --- /dev/null +++ b/internal/trie/node/encode_test.go @@ -0,0 +1,14 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import "errors" + +type writeCall struct { + written []byte + n int // number of bytes + err error +} + +var errTest = errors.New("test error") diff --git a/internal/trie/node/generation.go b/internal/trie/node/generation.go new file mode 100644 index 0000000000..113c283328 --- /dev/null +++ b/internal/trie/node/generation.go @@ -0,0 +1,24 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// SetGeneration sets the generation given to the branch. +func (b *Branch) SetGeneration(generation uint64) { + b.generation = generation +} + +// GetGeneration returns the generation of the branch. +func (b *Branch) GetGeneration() (generation uint64) { + return b.generation +} + +// SetGeneration sets the generation given to the leaf. +func (l *Leaf) SetGeneration(generation uint64) { + l.generation = generation +} + +// GetGeneration returns the generation of the leaf. 
+func (l *Leaf) GetGeneration() (generation uint64) { + return l.generation +} diff --git a/internal/trie/node/generation_test.go b/internal/trie/node/generation_test.go new file mode 100644 index 0000000000..708d93058e --- /dev/null +++ b/internal/trie/node/generation_test.go @@ -0,0 +1,50 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_SetGeneration(t *testing.T) { + t.Parallel() + + branch := &Branch{ + generation: 1, + } + branch.SetGeneration(2) + assert.Equal(t, &Branch{generation: 2}, branch) +} + +func Test_Branch_GetGeneration(t *testing.T) { + t.Parallel() + + const generation uint64 = 1 + branch := &Branch{ + generation: generation, + } + assert.Equal(t, branch.GetGeneration(), generation) +} + +func Test_Leaf_SetGeneration(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + generation: 1, + } + leaf.SetGeneration(2) + assert.Equal(t, &Leaf{generation: 2}, leaf) +} + +func Test_Leaf_GetGeneration(t *testing.T) { + t.Parallel() + + const generation uint64 = 1 + leaf := &Leaf{ + generation: generation, + } + assert.Equal(t, leaf.GetGeneration(), generation) +} diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go new file mode 100644 index 0000000000..315ad9b738 --- /dev/null +++ b/internal/trie/node/hash.go @@ -0,0 +1,135 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + + "github.com/ChainSafe/gossamer/internal/trie/pools" + "github.com/ChainSafe/gossamer/lib/common" +) + +// SetEncodingAndHash sets the encoding and hash slices +// given to the branch. Note it does not copy them, so beware. +func (b *Branch) SetEncodingAndHash(enc, hash []byte) { + b.encoding = enc + b.hashDigest = hash +} + +// GetHash returns the hash of the branch. 
+// Note it does not copy it, so modifying +// the returned hash will modify the hash +// of the branch. +func (b *Branch) GetHash() []byte { + return b.hashDigest +} + +// EncodeAndHash returns the encoding of the branch and +// the blake2b hash digest of the encoding of the branch. +// If the encoding is less than 32 bytes, the hash returned +// is the encoding and not the hash of the encoding. +func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { + if !b.dirty && b.encoding != nil && b.hashDigest != nil { + return b.encoding, b.hashDigest, nil + } + + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.EncodingBuffers.Put(buffer) + + err = b.Encode(buffer) + if err != nil { + return nil, nil, err + } + + bufferBytes := buffer.Bytes() + + b.encoding = make([]byte, len(bufferBytes)) + copy(b.encoding, bufferBytes) + encoding = b.encoding // no need to copy + + if buffer.Len() < 32 { + b.hashDigest = make([]byte, len(bufferBytes)) + copy(b.hashDigest, bufferBytes) + hash = b.hashDigest // no need to copy + return encoding, hash, nil + } + + // Note: using the sync.Pool's buffer is useful here. + hashArray, err := common.Blake2bHash(buffer.Bytes()) + if err != nil { + return nil, nil, err + } + b.hashDigest = hashArray[:] + hash = b.hashDigest // no need to copy + + return encoding, hash, nil +} + +// SetEncodingAndHash sets the encoding and hash slices +// given to the branch. Note it does not copy them, so beware. +func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { + l.encodingMu.Lock() + l.encoding = enc + l.encodingMu.Unlock() + l.hashDigest = hash +} + +// GetHash returns the hash of the leaf. +// Note it does not copy it, so modifying +// the returned hash will modify the hash +// of the branch. +func (l *Leaf) GetHash() []byte { + return l.hashDigest +} + +// EncodeAndHash returns the encoding of the leaf and +// the blake2b hash digest of the encoding of the leaf. 
+// If the encoding is less than 32 bytes, the hash returned +// is the encoding and not the hash of the encoding. +func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { + l.encodingMu.RLock() + if !l.IsDirty() && l.encoding != nil && l.hashDigest != nil { + l.encodingMu.RUnlock() + return l.encoding, l.hashDigest, nil + } + l.encodingMu.RUnlock() + + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.EncodingBuffers.Put(buffer) + + err = l.Encode(buffer) + if err != nil { + return nil, nil, err + } + + bufferBytes := buffer.Bytes() + + l.encodingMu.Lock() + // TODO remove this copying since it defeats the purpose of `buffer` + // and the sync.Pool. + l.encoding = make([]byte, len(bufferBytes)) + copy(l.encoding, bufferBytes) + l.encodingMu.Unlock() + encoding = l.encoding // no need to copy + + if len(bufferBytes) < 32 { + l.hashDigest = make([]byte, len(bufferBytes)) + copy(l.hashDigest, bufferBytes) + hash = l.hashDigest // no need to copy + return encoding, hash, nil + } + + // Note: using the sync.Pool's buffer is useful here. 
+ hashArray, err := common.Blake2bHash(buffer.Bytes()) + if err != nil { + return nil, nil, err + } + + l.hashDigest = hashArray[:] + hash = l.hashDigest // no need to copy + + return encoding, hash, nil +} diff --git a/internal/trie/node/hash_test.go b/internal/trie/node/hash_test.go new file mode 100644 index 0000000000..26693cd76b --- /dev/null +++ b/internal/trie/node/hash_test.go @@ -0,0 +1,254 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_SetEncodingAndHash(t *testing.T) { + t.Parallel() + + branch := &Branch{ + encoding: []byte{2}, + hashDigest: []byte{3}, + } + branch.SetEncodingAndHash([]byte{4}, []byte{5}) + + expectedBranch := &Branch{ + encoding: []byte{4}, + hashDigest: []byte{5}, + } + assert.Equal(t, expectedBranch, branch) +} + +func Test_Branch_GetHash(t *testing.T) { + t.Parallel() + + branch := &Branch{ + hashDigest: []byte{3}, + } + hash := branch.GetHash() + + expectedHash := []byte{3} + assert.Equal(t, expectedHash, hash) +} + +func Test_Branch_EncodeAndHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + expectedBranch *Branch + encoding []byte + hash []byte + errWrapped error + errMessage string + }{ + "empty branch": { + branch: &Branch{}, + expectedBranch: &Branch{ + encoding: []byte{0x80, 0x0, 0x0}, + hashDigest: []byte{0x80, 0x0, 0x0}, + }, + encoding: []byte{0x80, 0x0, 0x0}, + hash: []byte{0x80, 0x0, 0x0}, + }, + "small branch encoding": { + branch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + }, + expectedBranch: &Branch{ + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hashDigest: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + "branch dirty with precomputed encoding and hash": { + branch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + dirty: 
true, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedBranch: &Branch{ + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hashDigest: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + "branch not dirty with precomputed encoding and hash": { + branch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + dirty: false, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedBranch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + encoding: []byte{3}, + hash: []byte{4}, + }, + "large branch encoding": { + branch: &Branch{ + Key: repeatBytes(65, 7), + }, + expectedBranch: &Branch{ + encoding: []byte{0xbf, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0, 0x0}, //nolint:lll + hashDigest: []byte{0x6b, 0xd8, 0xcc, 0xac, 0x71, 0x77, 0x44, 0x17, 0xfe, 0xe0, 0xde, 0xda, 0xd5, 0x97, 0x6e, 0x69, 0xeb, 0xe9, 0xdd, 0x80, 0x1d, 0x4b, 0x51, 0xf1, 0x5b, 0xf3, 0x4a, 0x93, 0x27, 0x32, 0x2c, 0xb0}, //nolint:lll + }, + encoding: []byte{0xbf, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0, 0x0}, //nolint:lll + hash: []byte{0x6b, 0xd8, 0xcc, 0xac, 0x71, 0x77, 0x44, 0x17, 0xfe, 0xe0, 0xde, 0xda, 0xd5, 0x97, 0x6e, 0x69, 0xeb, 0xe9, 0xdd, 0x80, 0x1d, 0x4b, 0x51, 0xf1, 0x5b, 0xf3, 0x4a, 0x93, 0x27, 0x32, 0x2c, 0xb0}, //nolint:lll + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, hash, err := testCase.branch.EncodeAndHash() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, 
err, testCase.errMessage) + } + assert.Equal(t, testCase.encoding, encoding) + assert.Equal(t, testCase.hash, hash) + }) + } +} + +func Test_Leaf_SetEncodingAndHash(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + encoding: []byte{2}, + hashDigest: []byte{3}, + } + leaf.SetEncodingAndHash([]byte{4}, []byte{5}) + + expectedLeaf := &Leaf{ + encoding: []byte{4}, + hashDigest: []byte{5}, + } + assert.Equal(t, expectedLeaf, leaf) +} + +func Test_Leaf_GetHash(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + hashDigest: []byte{3}, + } + hash := leaf.GetHash() + + expectedHash := []byte{3} + assert.Equal(t, expectedHash, hash) +} + +func Test_Leaf_EncodeAndHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + expectedLeaf *Leaf + encoding []byte + hash []byte + errWrapped error + errMessage string + }{ + "empty leaf": { + leaf: &Leaf{}, + expectedLeaf: &Leaf{ + encoding: []byte{0x40, 0x0}, + hashDigest: []byte{0x40, 0x0}, + }, + encoding: []byte{0x40, 0x0}, + hash: []byte{0x40, 0x0}, + }, + "small leaf encoding": { + leaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + }, + expectedLeaf: &Leaf{ + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hashDigest: []byte{0x41, 0x1, 0x4, 0x2}, + }, + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hash: []byte{0x41, 0x1, 0x4, 0x2}, + }, + "leaf dirty with precomputed encoding and hash": { + leaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + dirty: true, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedLeaf: &Leaf{ + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hashDigest: []byte{0x41, 0x1, 0x4, 0x2}, + }, + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hash: []byte{0x41, 0x1, 0x4, 0x2}, + }, + "leaf not dirty with precomputed encoding and hash": { + leaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + dirty: false, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedLeaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + encoding: 
[]byte{3}, + hash: []byte{4}, + }, + "large leaf encoding": { + leaf: &Leaf{ + Key: repeatBytes(65, 7), + }, + expectedLeaf: &Leaf{ + encoding: []byte{0x7f, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0}, //nolint:lll + hashDigest: []byte{0xfb, 0xae, 0x31, 0x4b, 0xef, 0x31, 0x9, 0xc7, 0x62, 0x99, 0x9d, 0x40, 0x9b, 0xd4, 0xdc, 0x64, 0xe7, 0x39, 0x46, 0x8b, 0xd3, 0xaf, 0xe8, 0x63, 0x9d, 0xf9, 0x41, 0x40, 0x76, 0x40, 0x10, 0xa3}, //nolint:lll + }, + encoding: []byte{0x7f, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0}, //nolint:lll + hash: []byte{0xfb, 0xae, 0x31, 0x4b, 0xef, 0x31, 0x9, 0xc7, 0x62, 0x99, 0x9d, 0x40, 0x9b, 0xd4, 0xdc, 0x64, 0xe7, 0x39, 0x46, 0x8b, 0xd3, 0xaf, 0xe8, 0x63, 0x9d, 0xf9, 0x41, 0x40, 0x76, 0x40, 0x10, 0xa3}, //nolint:lll + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, hash, err := testCase.leaf.EncodeAndHash() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.encoding, encoding) + assert.Equal(t, testCase.hash, hash) + }) + } +} diff --git a/internal/trie/node/header.go b/internal/trie/node/header.go new file mode 100644 index 0000000000..424d21e307 --- /dev/null +++ b/internal/trie/node/header.go @@ -0,0 +1,67 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "io" +) + +const ( + keyLenOffset = 0x3f +) + +// encodeHeader creates the encoded header for the branch. 
+func (b *Branch) encodeHeader(writer io.Writer) (err error) { + var header byte + if b.Value == nil { + header = 2 << 6 + } else { + header = 3 << 6 + } + + if len(b.Key) >= keyLenOffset { + header = header | keyLenOffset + _, err = writer.Write([]byte{header}) + if err != nil { + return err + } + + err = encodeKeyLength(len(b.Key), writer) + if err != nil { + return err + } + } else { + header = header | byte(len(b.Key)) + _, err = writer.Write([]byte{header}) + if err != nil { + return err + } + } + + return nil +} + +// encodeHeader creates the encoded header for the leaf. +func (l *Leaf) encodeHeader(writer io.Writer) (err error) { + var header byte = 1 << 6 + + if len(l.Key) < 63 { + header |= byte(len(l.Key)) + _, err = writer.Write([]byte{header}) + return err + } + + header |= keyLenOffset + _, err = writer.Write([]byte{header}) + if err != nil { + return err + } + + err = encodeKeyLength(len(l.Key), writer) + if err != nil { + return err + } + + return nil +} diff --git a/internal/trie/node/header_test.go b/internal/trie/node/header_test.go new file mode 100644 index 0000000000..78f40344ca --- /dev/null +++ b/internal/trie/node/header_test.go @@ -0,0 +1,245 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func Test_Branch_encodeHeader(t *testing.T) { + testCases := map[string]struct { + branch *Branch + writes []writeCall + errWrapped error + errMessage string + }{ + "no key": { + branch: &Branch{}, + writes: []writeCall{ + {written: []byte{0x80}}, + }, + }, + "with value": { + branch: &Branch{ + Value: []byte{}, + }, + writes: []writeCall{ + {written: []byte{0xc0}}, + }, + }, + "key of length 30": { + branch: &Branch{ + Key: make([]byte, 30), + }, + writes: []writeCall{ + {written: []byte{0x9e}}, + }, + }, + "key of length 62": { + branch: &Branch{ + Key: make([]byte, 62), + }, + writes: 
[]writeCall{ + {written: []byte{0xbe}}, + }, + }, + "key of length 63": { + branch: &Branch{ + Key: make([]byte, 63), + }, + writes: []writeCall{ + {written: []byte{0xbf}}, + {written: []byte{0x0}}, + }, + }, + "key of length 64": { + branch: &Branch{ + Key: make([]byte, 64), + }, + writes: []writeCall{ + {written: []byte{0xbf}}, + {written: []byte{0x1}}, + }, + }, + "key too big": { + branch: &Branch{ + Key: make([]byte, 65535+63), + }, + writes: []writeCall{ + {written: []byte{0xbf}}, + }, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", + }, + "small key length write error": { + branch: &Branch{}, + writes: []writeCall{ + { + written: []byte{0x80}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: "test error", + }, + "long key length write error": { + branch: &Branch{ + Key: make([]byte, 64), + }, + writes: []writeCall{ + { + written: []byte{0xbf}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: "test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := writer.EXPECT(). + Write(write.written). 
+ Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := testCase.branch.encodeHeader(writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_Leaf_encodeHeader(t *testing.T) { + testCases := map[string]struct { + leaf *Leaf + writes []writeCall + errWrapped error + errMessage string + }{ + "no key": { + leaf: &Leaf{}, + writes: []writeCall{ + {written: []byte{0x40}}, + }, + }, + "key of length 30": { + leaf: &Leaf{ + Key: make([]byte, 30), + }, + writes: []writeCall{ + {written: []byte{0x5e}}, + }, + }, + "short key write error": { + leaf: &Leaf{ + Key: make([]byte, 30), + }, + writes: []writeCall{ + { + written: []byte{0x5e}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + "key of length 62": { + leaf: &Leaf{ + Key: make([]byte, 62), + }, + writes: []writeCall{ + {written: []byte{0x7e}}, + }, + }, + "key of length 63": { + leaf: &Leaf{ + Key: make([]byte, 63), + }, + writes: []writeCall{ + {written: []byte{0x7f}}, + {written: []byte{0x0}}, + }, + }, + "key of length 64": { + leaf: &Leaf{ + Key: make([]byte, 64), + }, + writes: []writeCall{ + {written: []byte{0x7f}}, + {written: []byte{0x1}}, + }, + }, + "long key first byte write error": { + leaf: &Leaf{ + Key: make([]byte, 63), + }, + writes: []writeCall{ + { + written: []byte{0x7f}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + "key too big": { + leaf: &Leaf{ + Key: make([]byte, 65535+63), + }, + writes: []writeCall{ + {written: []byte{0x7f}}, + }, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := NewMockWriter(ctrl) + 
var previousCall *gomock.Call + for _, write := range testCase.writes { + call := writer.EXPECT(). + Write(write.written). + Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := testCase.leaf.encodeHeader(writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go new file mode 100644 index 0000000000..3478ef3aa7 --- /dev/null +++ b/internal/trie/node/key.go @@ -0,0 +1,123 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/pools" +) + +// GetKey returns the key of the branch. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the branch. +func (b *Branch) GetKey() (value []byte) { + return b.Key +} + +// GetKey returns the key of the leaf. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the leaf. +func (l *Leaf) GetKey() (value []byte) { + return l.Key +} + +// SetKey sets the key to the branch. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the branch. +func (b *Branch) SetKey(key []byte) { + b.Key = key +} + +// SetKey sets the key to the leaf. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the leaf. 
+func (l *Leaf) SetKey(key []byte) { + l.Key = key +} + +const maxPartialKeySize = ^uint16(0) + +var ( + ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16") + ErrReadKeyLength = errors.New("cannot read key length") + ErrReadKeyData = errors.New("cannot read key data") +) + +// encodeKeyLength encodes the key length. +func encodeKeyLength(keyLength int, writer io.Writer) (err error) { + keyLength -= 63 + + if keyLength >= int(maxPartialKeySize) { + return fmt.Errorf("%w: %d", + ErrPartialKeyTooBig, keyLength) + } + + for i := uint16(0); i < maxPartialKeySize; i++ { + if keyLength < 255 { + _, err = writer.Write([]byte{byte(keyLength)}) + if err != nil { + return err + } + break + } + _, err = writer.Write([]byte{255}) + if err != nil { + return err + } + + keyLength -= 255 + } + + return nil +} + +// decodeKey decodes a key from a reader. +func decodeKey(reader io.Reader, keyLengthByte byte) (b []byte, err error) { + keyLength := int(keyLengthByte) + + if keyLengthByte == keyLenOffset { + // partial key longer than 63, read next bytes for rest of pk len + buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) + defer pools.SingleByteBuffers.Put(buffer) + oneByteBuf := buffer.Bytes() + for { + _, err = reader.Read(oneByteBuf) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadKeyLength, err) + } + nextKeyLen := oneByteBuf[0] + + keyLength += int(nextKeyLen) + + if nextKeyLen < 0xff { + break + } + + if keyLength >= int(maxPartialKeySize) { + return nil, fmt.Errorf("%w: %d", + ErrPartialKeyTooBig, keyLength) + } + } + } + + if keyLength == 0 { + return []byte{}, nil + } + + key := make([]byte, keyLength/2+keyLength%2) + n, err := reader.Read(key) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadKeyData, err) + } else if n != len(key) { + return nil, fmt.Errorf("%w: read %d bytes instead of %d", + ErrReadKeyData, n, len(key)) + } + + return codec.KeyLEToNibbles(key)[keyLength%2:], nil +} diff --git 
a/internal/trie/node/key_test.go b/internal/trie/node/key_test.go new file mode 100644 index 0000000000..c3413c1628 --- /dev/null +++ b/internal/trie/node/key_test.go @@ -0,0 +1,334 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Branch_GetKey(t *testing.T) { + t.Parallel() + + branch := &Branch{ + Key: []byte{2}, + } + key := branch.GetKey() + assert.Equal(t, []byte{2}, key) +} + +func Test_Leaf_GetKey(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + Key: []byte{2}, + } + key := leaf.GetKey() + assert.Equal(t, []byte{2}, key) +} + +func Test_Branch_SetKey(t *testing.T) { + t.Parallel() + + branch := &Branch{ + Key: []byte{2}, + } + branch.SetKey([]byte{3}) + assert.Equal(t, &Branch{Key: []byte{3}}, branch) +} + +func Test_Leaf_SetKey(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + Key: []byte{2}, + } + leaf.SetKey([]byte{3}) + assert.Equal(t, &Leaf{Key: []byte{3}}, leaf) +} + +func repeatBytes(n int, b byte) (slice []byte) { + slice = make([]byte, n) + for i := range slice { + slice[i] = b + } + return slice +} + +func Test_encodeKeyLength(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + keyLength int + writes []writeCall + errWrapped error + errMessage string + }{ + "length equal to maximum": { + keyLength: int(maxPartialKeySize) + 63, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be " + + "larger than or equal to 2^16: 65535", + }, + "zero length": { + writes: []writeCall{ + { + written: []byte{0xc1}, + }, + }, + }, + "one length": { + keyLength: 1, + writes: []writeCall{ + { + written: []byte{0xc2}, + }, + }, + }, + "error at single byte write": { + keyLength: 1, + writes: []writeCall{ + { + written: []byte{0xc2}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: 
errTest.Error(), + }, + "error at first byte write": { + keyLength: 255 + 100 + 63, + writes: []writeCall{ + { + written: []byte{255}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + "error at last byte write": { + keyLength: 255 + 100 + 63, + writes: []writeCall{ + { + written: []byte{255}, + }, + { + written: []byte{100}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := writer.EXPECT(). + Write(write.written). + Return(write.n, write.err) + + if write.err != nil { + break + } else if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := encodeKeyLength(testCase.keyLength, writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } + + t.Run("length at maximum", func(t *testing.T) { + t.Parallel() + + // Note: this test case cannot run with the + // mock writer since it's too slow, so we use + // an actual buffer. 
+ + const keyLength = int(maxPartialKeySize) + 62 + const expectedEncodingLength = 257 + expectedBytes := make([]byte, expectedEncodingLength) + for i := 0; i < len(expectedBytes)-1; i++ { + expectedBytes[i] = 255 + } + expectedBytes[len(expectedBytes)-1] = 254 + + buffer := bytes.NewBuffer(nil) + buffer.Grow(expectedEncodingLength) + + err := encodeKeyLength(keyLength, buffer) + + require.NoError(t, err) + assert.Equal(t, expectedBytes, buffer.Bytes()) + }) +} + +//go:generate mockgen -destination=reader_mock_test.go -package $GOPACKAGE io Reader + +type readCall struct { + buffArgCap int + read []byte + n int // number of bytes read + err error +} + +func repeatReadCalls(rc readCall, length int) (readCalls []readCall) { + readCalls = make([]readCall, length) + for i := range readCalls { + readCalls[i] = readCall{ + buffArgCap: rc.buffArgCap, + n: rc.n, + err: rc.err, + } + if rc.read != nil { + readCalls[i].read = make([]byte, len(rc.read)) + copy(readCalls[i].read, rc.read) + } + } + return readCalls +} + +var _ gomock.Matcher = (*byteSliceCapMatcher)(nil) + +type byteSliceCapMatcher struct { + capacity int +} + +func (b *byteSliceCapMatcher) Matches(x interface{}) bool { + slice, ok := x.([]byte) + if !ok { + return false + } + return cap(slice) == b.capacity +} + +func (b *byteSliceCapMatcher) String() string { + return fmt.Sprintf("capacity of slice is not the expected capacity %d", b.capacity) +} + +func newByteSliceCapMatcher(capacity int) *byteSliceCapMatcher { + return &byteSliceCapMatcher{ + capacity: capacity, + } +} + +func Test_decodeKey(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reads []readCall + keyLength byte + b []byte + errWrapped error + errMessage string + }{ + "zero key length": { + b: []byte{}, + }, + "short key length": { + reads: []readCall{ + {buffArgCap: 3, read: []byte{1, 2, 3}, n: 3}, + }, + keyLength: 5, + b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, + }, + "key read error": { + reads: []readCall{ + {buffArgCap: 3, 
err: errTest}, + }, + keyLength: 5, + errWrapped: ErrReadKeyData, + errMessage: "cannot read key data: test error", + }, + + "key read bytes count mismatch": { + reads: []readCall{ + {buffArgCap: 3, n: 2}, + }, + keyLength: 5, + errWrapped: ErrReadKeyData, + errMessage: "cannot read key data: read 2 bytes instead of 3", + }, + "long key length": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{6}, n: 1}, // key length + {buffArgCap: 35, read: repeatBytes(35, 7), n: 35}, // key data + }, + keyLength: 0x3f, + b: []byte{ + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, + }, + "key length read error": { + reads: []readCall{ + {buffArgCap: 1, err: errTest}, + }, + keyLength: 0x3f, + errWrapped: ErrReadKeyLength, + errMessage: "cannot read key length: test error", + }, + "key length too big": { + reads: repeatReadCalls(readCall{buffArgCap: 1, read: []byte{0xff}, n: 1}, 257), + keyLength: 0x3f, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65598", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + reader := NewMockReader(ctrl) + var previousCall *gomock.Call + for _, readCall := range testCase.reads { + byteSliceCapMatcher := newByteSliceCapMatcher(readCall.buffArgCap) + call := reader.EXPECT().Read(byteSliceCapMatcher). 
+ DoAndReturn(func(b []byte) (n int, err error) { + copy(b, readCall.read) + return readCall.n, readCall.err + }) + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + b, err := decodeKey(reader, testCase.keyLength) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.b, b) + }) + } +} diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go new file mode 100644 index 0000000000..0de16a3881 --- /dev/null +++ b/internal/trie/node/leaf.go @@ -0,0 +1,48 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "fmt" + "sync" + + "github.com/ChainSafe/gossamer/lib/common" +) + +var _ Node = (*Leaf)(nil) + +// Leaf is a leaf in the trie. +type Leaf struct { + Key []byte // partial key + Value []byte + // Dirty is true when the branch differs + // from the node stored in the database. + dirty bool + hashDigest []byte + encoding []byte + encodingMu sync.RWMutex + // generation is incremented on every trie Snapshot() call. + // Each node also contain a certain generation number, + // which is updated to match the trie generation once they are + // inserted, moved or iterated over. + generation uint64 + sync.RWMutex +} + +// NewLeaf creates a new leaf using the arguments given. 
+func NewLeaf(key, value []byte, dirty bool, generation uint64) *Leaf { + return &Leaf{ + Key: key, + Value: value, + dirty: dirty, + generation: generation, + } +} + +func (l *Leaf) String() string { + if len(l.Value) > 1024 { + return fmt.Sprintf("leaf key=0x%x value (hashed)=0x%x dirty=%t", l.Key, common.MustBlake2bHash(l.Value), l.dirty) + } + return fmt.Sprintf("leaf key=0x%x value=0x%x dirty=%t", l.Key, l.Value, l.dirty) +} diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go new file mode 100644 index 0000000000..f18bbf0d28 --- /dev/null +++ b/internal/trie/node/leaf_encode.go @@ -0,0 +1,117 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "bytes" + "fmt" + "hash" + "io" + + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/pools" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +// Encode encodes a leaf to the buffer given. +// The encoding has the following format: +// NodeHeader | Extra partial key length | Partial Key | Value +func (l *Leaf) Encode(buffer Buffer) (err error) { + l.encodingMu.RLock() + if !l.dirty && l.encoding != nil { + _, err = buffer.Write(l.encoding) + l.encodingMu.RUnlock() + if err != nil { + return fmt.Errorf("cannot write stored encoding to buffer: %w", err) + } + return nil + } + l.encodingMu.RUnlock() + + err = l.encodeHeader(buffer) + if err != nil { + return fmt.Errorf("cannot encode header: %w", err) + } + + keyLE := codec.NibblesToKeyLE(l.Key) + _, err = buffer.Write(keyLE) + if err != nil { + return fmt.Errorf("cannot write LE key to buffer: %w", err) + } + + encodedValue, err := scale.Marshal(l.Value) // TODO scale encoder to write to buffer + if err != nil { + return fmt.Errorf("cannot scale marshal value: %w", err) + } + + _, err = buffer.Write(encodedValue) + if err != nil { + return fmt.Errorf("cannot write scale encoded value to buffer: %w", err) + } + + // TODO remove 
this copying since it defeats the purpose of `buffer` + // and the sync.Pool. + l.encodingMu.Lock() + defer l.encodingMu.Unlock() + l.encoding = make([]byte, buffer.Len()) + copy(l.encoding, buffer.Bytes()) + return nil +} + +// ScaleEncodeHash hashes the node (blake2b sum on encoded value) +// and then SCALE encodes it. This is used to encode children +// nodes of branches. +func (l *Leaf) ScaleEncodeHash() (encoding []byte, err error) { + buffer := pools.DigestBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.DigestBuffers.Put(buffer) + + err = l.hash(buffer) + if err != nil { + return nil, fmt.Errorf("cannot hash leaf: %w", err) + } + + scEncChild, err := scale.Marshal(buffer.Bytes()) + if err != nil { + return nil, fmt.Errorf("cannot scale encode hashed leaf: %w", err) + } + return scEncChild, nil +} + +func (l *Leaf) hash(writer io.Writer) (err error) { + encodingBuffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + encodingBuffer.Reset() + defer pools.EncodingBuffers.Put(encodingBuffer) + + err = l.Encode(encodingBuffer) + if err != nil { + return fmt.Errorf("cannot encode leaf: %w", err) + } + + // if length of encoded leaf is less than 32 bytes, do not hash + if encodingBuffer.Len() < 32 { + _, err = writer.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot write encoded leaf to buffer: %w", err) + } + return nil + } + + // otherwise, hash encoded node + hasher := pools.Hashers.Get().(hash.Hash) + hasher.Reset() + defer pools.Hashers.Put(hasher) + + // Note: using the sync.Pool's buffer is useful here. 
+ _, err = hasher.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot hash encoded node: %w", err) + } + + _, err = writer.Write(hasher.Sum(nil)) + if err != nil { + return fmt.Errorf("cannot write hash sum of leaf to buffer: %w", err) + } + return nil +} diff --git a/internal/trie/node/leaf_encode_test.go b/internal/trie/node/leaf_encode_test.go new file mode 100644 index 0000000000..61eb78ad9b --- /dev/null +++ b/internal/trie/node/leaf_encode_test.go @@ -0,0 +1,296 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Leaf_Encode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + writes []writeCall + bufferLenCall bool + bufferBytesCall bool + bufferBytes []byte + expectedEncoding []byte + wrappedErr error + errMessage string + }{ + "clean leaf with encoding": { + leaf: &Leaf{ + encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{1, 2, 3}, + }, + }, + expectedEncoding: []byte{1, 2, 3}, + }, + "write error for clean leaf with encoding": { + leaf: &Leaf{ + encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{1, 2, 3}, + err: errTest, + }, + }, + expectedEncoding: []byte{1, 2, 3}, + wrappedErr: errTest, + errMessage: "cannot write stored encoding to buffer: test error", + }, + "header encoding error": { + leaf: &Leaf{ + Key: make([]byte, 63+(1<<16)), + }, + writes: []writeCall{ + { + written: []byte{127}, + }, + }, + wrappedErr: ErrPartialKeyTooBig, + errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", + }, + "buffer write error for encoded key": { + leaf: &Leaf{ + Key: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{67}, + }, + { + written: []byte{1, 35}, + err: errTest, + }, + }, + 
wrappedErr: errTest, + errMessage: "cannot write LE key to buffer: test error", + }, + "buffer write error for encoded value": { + leaf: &Leaf{ + Key: []byte{1, 2, 3}, + Value: []byte{4, 5, 6}, + }, + writes: []writeCall{ + { + written: []byte{67}, + }, + { + written: []byte{1, 35}, + }, + { + written: []byte{12, 4, 5, 6}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write scale encoded value to buffer: test error", + }, + "success": { + leaf: &Leaf{ + Key: []byte{1, 2, 3}, + Value: []byte{4, 5, 6}, + }, + writes: []writeCall{ + { + written: []byte{67}, + }, + { + written: []byte{1, 35}, + }, + { + written: []byte{12, 4, 5, 6}, + }, + }, + bufferLenCall: true, + bufferBytesCall: true, + bufferBytes: []byte{1, 2, 3}, + expectedEncoding: []byte{1, 2, 3}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockBuffer(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := buffer.EXPECT(). + Write(write.written). 
+ Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + if testCase.bufferLenCall { + buffer.EXPECT().Len().Return(len(testCase.bufferBytes)) + } + if testCase.bufferBytesCall { + buffer.EXPECT().Bytes().Return(testCase.bufferBytes) + } + + err := testCase.leaf.Encode(buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + assert.Equal(t, testCase.expectedEncoding, testCase.leaf.encoding) + }) + } +} + +func Test_Leaf_ScaleEncodeHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + b []byte + wrappedErr error + errMessage string + }{ + "leaf": { + leaf: &Leaf{}, + b: []byte{0x8, 0x40, 0}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + b, err := testCase.leaf.ScaleEncodeHash() + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + + assert.Equal(t, testCase.b, b) + }) + } +} + +func Test_Leaf_hash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + writeCall bool + write writeCall + wrappedErr error + errMessage string + }{ + "small leaf buffer write error": { + leaf: &Leaf{ + encoding: []byte{1, 2, 3}, + }, + writeCall: true, + write: writeCall{ + written: []byte{1, 2, 3}, + err: errTest, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded leaf to buffer: " + + "test error", + }, + "small leaf success": { + leaf: &Leaf{ + encoding: []byte{1, 2, 3}, + }, + writeCall: true, + write: writeCall{ + written: []byte{1, 2, 3}, + }, + }, + "leaf hash sum buffer write error": { + leaf: &Leaf{ + encoding: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 
1, 2, 3, 4, 5, 6, 7, 8, + }, + }, + writeCall: true, + write: writeCall{ + written: []byte{ + 107, 105, 154, 175, 253, 170, 232, + 135, 240, 21, 207, 148, 82, 117, + 249, 230, 80, 197, 254, 17, 149, + 108, 50, 7, 80, 56, 114, 176, + 84, 114, 125, 234}, + err: errTest, + }, + wrappedErr: errTest, + errMessage: "cannot write hash sum of leaf to buffer: " + + "test error", + }, + "leaf hash sum success": { + leaf: &Leaf{ + encoding: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + }, + }, + writeCall: true, + write: writeCall{ + written: []byte{ + 107, 105, 154, 175, 253, 170, 232, + 135, 240, 21, 207, 148, 82, 117, + 249, 230, 80, 197, 254, 17, 149, + 108, 50, 7, 80, 56, 114, 176, + 84, 114, 125, 234}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := NewMockWriter(ctrl) + if testCase.writeCall { + writer.EXPECT(). + Write(testCase.write.written). 
+ Return(testCase.write.n, testCase.write.err) + } + + err := testCase.leaf.hash(writer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/trie/node/leaf_test.go b/internal/trie/node/leaf_test.go new file mode 100644 index 0000000000..d755eb724d --- /dev/null +++ b/internal/trie/node/leaf_test.go @@ -0,0 +1,77 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewLeaf(t *testing.T) { + t.Parallel() + + key := []byte{1, 2} + value := []byte{3, 4} + const dirty = true + const generation = 9 + + leaf := NewLeaf(key, value, dirty, generation) + + expectedLeaf := &Leaf{ + Key: key, + Value: value, + dirty: dirty, + generation: generation, + } + assert.Equal(t, expectedLeaf, leaf) + + // Check modifying passed slice modifies leaf slices + key[0] = 11 + value[0] = 13 + assert.Equal(t, expectedLeaf, leaf) +} + +func Test_Leaf_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + s string + }{ + "empty leaf": { + leaf: &Leaf{}, + s: "leaf key=0x value=0x dirty=false", + }, + "leaf with value smaller than 1024": { + leaf: &Leaf{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + }, + s: "leaf key=0x0102 value=0x0304 dirty=true", + }, + "leaf with value higher than 1024": { + leaf: &Leaf{ + Key: []byte{1, 2}, + Value: make([]byte, 1025), + dirty: true, + }, + s: "leaf key=0x0102 " + + "value (hashed)=0x307861663233363133353361303538646238383034626337353735323831663131663735313265326331346336373032393864306232336630396538386266333066 " + //nolint:lll + "dirty=true", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := testCase.leaf.String() + + assert.Equal(t, 
testCase.s, s) + }) + } +} diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go new file mode 100644 index 0000000000..6c306979bb --- /dev/null +++ b/internal/trie/node/node.go @@ -0,0 +1,22 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// Node is a node in the trie and can be a leaf or a branch. +type Node interface { + Encode(buffer Buffer) (err error) // TODO change to io.Writer + EncodeAndHash() (encoding []byte, hash []byte, err error) + ScaleEncodeHash() (encoding []byte, err error) + IsDirty() bool + SetDirty(dirty bool) + SetKey(key []byte) + String() string + SetEncodingAndHash(encoding []byte, hash []byte) + GetHash() (hash []byte) + GetKey() (key []byte) + GetValue() (value []byte) + GetGeneration() (generation uint64) + SetGeneration(generation uint64) + Copy() Node +} diff --git a/internal/trie/node/reader_mock_test.go b/internal/trie/node/reader_mock_test.go new file mode 100644 index 0000000000..2aa28d2998 --- /dev/null +++ b/internal/trie/node/reader_mock_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: io (interfaces: Reader) + +// Package node is a generated GoMock package. +package node + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockReader is a mock of Reader interface. +type MockReader struct { + ctrl *gomock.Controller + recorder *MockReaderMockRecorder +} + +// MockReaderMockRecorder is the mock recorder for MockReader. +type MockReaderMockRecorder struct { + mock *MockReader +} + +// NewMockReader creates a new mock instance. +func NewMockReader(ctrl *gomock.Controller) *MockReader { + mock := &MockReader{ctrl: ctrl} + mock.recorder = &MockReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReader) EXPECT() *MockReaderMockRecorder { + return m.recorder +} + +// Read mocks base method. 
+func (m *MockReader) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockReaderMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReader)(nil).Read), arg0) +} diff --git a/internal/trie/node/types.go b/internal/trie/node/types.go new file mode 100644 index 0000000000..5f0ef8191b --- /dev/null +++ b/internal/trie/node/types.go @@ -0,0 +1,17 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// Type is the byte type for the node. +type Type byte + +const ( + _ Type = iota + // LeafType type is 1 + LeafType + // BranchType type is 2 + BranchType + // BranchWithValueType type is 3 + BranchWithValueType +) diff --git a/internal/trie/node/value.go b/internal/trie/node/value.go new file mode 100644 index 0000000000..5ab07fb589 --- /dev/null +++ b/internal/trie/node/value.go @@ -0,0 +1,18 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// GetValue returns the value of the branch. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the branch. +func (b *Branch) GetValue() (value []byte) { + return b.Value +} + +// GetValue returns the value of the leaf. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the leaf. 
+func (l *Leaf) GetValue() (value []byte) { + return l.Value +} diff --git a/internal/trie/node/value_test.go b/internal/trie/node/value_test.go new file mode 100644 index 0000000000..f6fe989d1d --- /dev/null +++ b/internal/trie/node/value_test.go @@ -0,0 +1,30 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_GetValue(t *testing.T) { + t.Parallel() + + branch := &Branch{ + Value: []byte{2}, + } + value := branch.GetValue() + assert.Equal(t, []byte{2}, value) +} + +func Test_Leaf_GetValue(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + Value: []byte{2}, + } + value := leaf.GetValue() + assert.Equal(t, []byte{2}, value) +} diff --git a/lib/trie/writer_mock_test.go b/internal/trie/node/writer_mock_test.go similarity index 95% rename from lib/trie/writer_mock_test.go rename to internal/trie/node/writer_mock_test.go index b1009272f2..9665f01c85 100644 --- a/lib/trie/writer_mock_test.go +++ b/internal/trie/node/writer_mock_test.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: io (interfaces: Writer) -// Package trie is a generated GoMock package. -package trie +// Package node is a generated GoMock package. +package node import ( reflect "reflect" diff --git a/internal/trie/pools/pools.go b/internal/trie/pools/pools.go new file mode 100644 index 0000000000..855232ef44 --- /dev/null +++ b/internal/trie/pools/pools.go @@ -0,0 +1,51 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package pools + +import ( + "bytes" + "sync" + + "golang.org/x/crypto/blake2b" +) + +// SingleByteBuffers is a sync pool of buffers of capacity 1. +var SingleByteBuffers = &sync.Pool{ + New: func() interface{} { + const bufferLength = 1 + b := make([]byte, bufferLength) + return bytes.NewBuffer(b) + }, +} + +// DigestBuffers is a sync pool of buffers of capacity 32. 
+var DigestBuffers = &sync.Pool{ + New: func() interface{} { + const bufferCapacity = 32 + b := make([]byte, 0, bufferCapacity) + return bytes.NewBuffer(b) + }, +} + +// EncodingBuffers is a sync pool of buffers of capacity 1.9MB. +var EncodingBuffers = &sync.Pool{ + New: func() interface{} { + const initialBufferCapacity = 1900000 // 1.9MB, from checking capacities at runtime + b := make([]byte, 0, initialBufferCapacity) + return bytes.NewBuffer(b) + }, +} + +// Hashers is a sync pool of blake2b 256 hashers. +var Hashers = &sync.Pool{ + New: func() interface{} { + hasher, err := blake2b.New256(nil) + if err != nil { + // Conversation on why we panic here: + // https://github.com/ChainSafe/gossamer/pull/2009#discussion_r753430764 + panic("cannot create Blake2b-256 hasher: " + err.Error()) + } + return hasher + }, +} diff --git a/internal/trie/record/node.go b/internal/trie/record/node.go new file mode 100644 index 0000000000..19a745c82c --- /dev/null +++ b/internal/trie/record/node.go @@ -0,0 +1,10 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package record + +// Node represents a record of a visited node +type Node struct { + RawData []byte + Hash []byte +} diff --git a/internal/trie/record/recorder.go b/internal/trie/record/recorder.go new file mode 100644 index 0000000000..130b434338 --- /dev/null +++ b/internal/trie/record/recorder.go @@ -0,0 +1,27 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package record + +// Recorder records the list of nodes found by Lookup.Find +type Recorder struct { + nodes []Node +} + +// NewRecorder creates a new recorder. +func NewRecorder() *Recorder { + return &Recorder{} +} + +// Record appends a node to the list of visited nodes. +func (r *Recorder) Record(hash, rawData []byte) { + r.nodes = append(r.nodes, Node{RawData: rawData, Hash: hash}) +} + +// GetNodes returns all the nodes recorded. +// Note it does not copy its slice of nodes. 
+// It's fine to not copy them since the recorder +// is not used again after a call to GetNodes() +func (r *Recorder) GetNodes() (nodes []Node) { + return r.nodes +} diff --git a/internal/trie/record/recorder_test.go b/internal/trie/record/recorder_test.go new file mode 100644 index 0000000000..943f82859d --- /dev/null +++ b/internal/trie/record/recorder_test.go @@ -0,0 +1,118 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package record + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewRecorder(t *testing.T) { + t.Parallel() + + expected := &Recorder{} + + recorder := NewRecorder() + + assert.Equal(t, expected, recorder) +} + +func Test_Recorder_Record(t *testing.T) { + testCases := map[string]struct { + recorder *Recorder + hash []byte + rawData []byte + expectedRecorder *Recorder + }{ + "nil data": { + recorder: &Recorder{}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {}, + }, + }, + }, + "insert in empty recorder": { + recorder: &Recorder{}, + hash: []byte{1, 2}, + rawData: []byte{3, 4}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + }, + "insert in non-empty recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + }, + }, + hash: []byte{1, 2}, + rawData: []byte{3, 4}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.recorder.Record(testCase.hash, testCase.rawData) + + assert.Equal(t, testCase.expectedRecorder, testCase.recorder) + }) + } +} + +func Test_Recorder_GetNodes(t *testing.T) { + testCases := map[string]struct { + recorder *Recorder + nodes []Node + }{ + "no node": { + recorder: &Recorder{}, + }, + "get single node from 
recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + nodes: []Node{{Hash: []byte{1, 2}, RawData: []byte{3, 4}}}, + }, + "get node from multiple nodes in recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, + }, + }, + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodes := testCase.recorder.GetNodes() + + assert.Equal(t, testCase.nodes, nodes) + }) + } +} diff --git a/lib/trie/bytesBuffer_mock_test.go b/lib/trie/bytesBuffer_mock_test.go deleted file mode 100644 index c59f7dd4a9..0000000000 --- a/lib/trie/bytesBuffer_mock_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: hash.go - -// Package trie is a generated GoMock package. -package trie - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockbytesBuffer is a mock of bytesBuffer interface. -type MockbytesBuffer struct { - ctrl *gomock.Controller - recorder *MockbytesBufferMockRecorder -} - -// MockbytesBufferMockRecorder is the mock recorder for MockbytesBuffer. -type MockbytesBufferMockRecorder struct { - mock *MockbytesBuffer -} - -// NewMockbytesBuffer creates a new mock instance. -func NewMockbytesBuffer(ctrl *gomock.Controller) *MockbytesBuffer { - mock := &MockbytesBuffer{ctrl: ctrl} - mock.recorder = &MockbytesBufferMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockbytesBuffer) EXPECT() *MockbytesBufferMockRecorder { - return m.recorder -} - -// Bytes mocks base method. 
-func (m *MockbytesBuffer) Bytes() []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Bytes") - ret0, _ := ret[0].([]byte) - return ret0 -} - -// Bytes indicates an expected call of Bytes. -func (mr *MockbytesBufferMockRecorder) Bytes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockbytesBuffer)(nil).Bytes)) -} - -// Len mocks base method. -func (m *MockbytesBuffer) Len() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Len") - ret0, _ := ret[0].(int) - return ret0 -} - -// Len indicates an expected call of Len. -func (mr *MockbytesBufferMockRecorder) Len() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockbytesBuffer)(nil).Len)) -} - -// Write mocks base method. -func (m *MockbytesBuffer) Write(p []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. 
-func (mr *MockbytesBufferMockRecorder) Write(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockbytesBuffer)(nil).Write), p) -} diff --git a/lib/trie/codec.go b/lib/trie/codec.go deleted file mode 100644 index 33bad34007..0000000000 --- a/lib/trie/codec.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -// keyToNibbles turns bytes into nibbles -// does not rearrange the nibbles; assumes they are already ordered in LE -func keyToNibbles(in []byte) []byte { - if len(in) == 0 { - return []byte{} - } else if len(in) == 1 && in[0] == 0 { - return []byte{0, 0} - } - - l := len(in) * 2 - res := make([]byte, l) - for i, b := range in { - res[2*i] = b / 16 - res[2*i+1] = b % 16 - } - - return res -} - -// nibblesToKey turns a slice of nibbles w/ length k into a big endian byte array -// if the length of the input is odd, the result is [ in[1] in[0] | ... | 0000 in[k-1] ] -// otherwise, res = [ in[1] in[0] | ... | in[k-1] in[k-2] ] -func nibblesToKey(in []byte) (res []byte) { - if len(in)%2 == 0 { - res = make([]byte, len(in)/2) - for i := 0; i < len(in); i += 2 { - res[i/2] = (in[i] & 0xf) | (in[i+1] << 4 & 0xf0) - } - } else { - res = make([]byte, len(in)/2+1) - for i := 0; i < len(in); i += 2 { - if i < len(in)-1 { - res[i/2] = (in[i] & 0xf) | (in[i+1] << 4 & 0xf0) - } else { - res[i/2] = (in[i] & 0xf) - } - } - } - - return res -} - -// nibblesToKey turns a slice of nibbles w/ length k into a little endian byte array -// assumes nibbles are already LE, does not rearrange nibbles -// if the length of the input is odd, the result is [ 0000 in[0] | in[1] in[2] | ... | in[k-2] in[k-1] ] -// otherwise, res = [ in[0] in[1] | ... 
| in[k-2] in[k-1] ] -func nibblesToKeyLE(in []byte) (res []byte) { - if len(in)%2 == 0 { - res = make([]byte, len(in)/2) - for i := 0; i < len(in); i += 2 { - res[i/2] = (in[i] << 4 & 0xf0) | (in[i+1] & 0xf) - } - } else { - res = make([]byte, len(in)/2+1) - res[0] = in[0] - for i := 2; i < len(in); i += 2 { - res[i/2] = (in[i-1] << 4 & 0xf0) | (in[i] & 0xf) - } - } - - return res -} diff --git a/lib/trie/codec_test.go b/lib/trie/codec_test.go deleted file mode 100644 index 108a5acfa7..0000000000 --- a/lib/trie/codec_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "fmt" - "testing" -) - -func TestKeyToNibbles(t *testing.T) { - tests := []struct { - input []byte - expected []byte - }{ - {[]byte{0x0}, []byte{0, 0}}, - {[]byte{0xFF}, []byte{0xF, 0xF}}, - {[]byte{0x3a, 0x05}, []byte{0x3, 0xa, 0x0, 0x5}}, - {[]byte{0xAA, 0xFF, 0x01}, []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}}, - {[]byte{0xAA, 0xFF, 0x01, 0xc2}, []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}}, - {[]byte{0xAA, 0xFF, 0x01, 0xc0}, []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x0}}, - } - - for _, test := range tests { - test := test - t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { - res := keyToNibbles(test.input) - if !bytes.Equal(test.expected, res) { - t.Errorf("Output doesn't match expected. 
got=%v expected=%v\n", res, test.expected) - } - }) - } -} - -func TestNibblesToKey(t *testing.T) { - tests := []struct { - input []byte - expected []byte - }{ - {[]byte{0xF, 0xF}, []byte{0xFF}}, - {[]byte{0x3, 0xa, 0x0, 0x5}, []byte{0xa3, 0x50}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, []byte{0xaa, 0xff, 0x10}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, []byte{0xaa, 0xff, 0x10, 0x2c}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, []byte{0xaa, 0xff, 0x10, 0x0c}}, - } - - for _, test := range tests { - test := test - t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { - res := nibblesToKey(test.input) - if !bytes.Equal(test.expected, res) { - t.Errorf("Output doesn't match expected. got=%x expected=%x\n", res, test.expected) - } - }) - } -} - -func TestNibblesToKeyLE(t *testing.T) { - tests := []struct { - input []byte - expected []byte - }{ - {[]byte{0xF, 0xF}, []byte{0xFF}}, - {[]byte{0x3, 0xa, 0x0, 0x5}, []byte{0x3a, 0x05}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, []byte{0xaa, 0xff, 0x01}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, []byte{0xaa, 0xff, 0x01, 0xc2}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, []byte{0xa, 0xaf, 0xf0, 0x1c}}, - } - - for _, test := range tests { - test := test - t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { - res := nibblesToKeyLE(test.input) - if !bytes.Equal(test.expected, res) { - t.Errorf("Output doesn't match expected. 
got=%x expected=%x\n", res, test.expected) - } - }) - } -} diff --git a/lib/trie/database.go b/lib/trie/database.go index 528664d078..f80c2d7b65 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/chaindb" @@ -31,12 +33,12 @@ func (t *Trie) Store(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) store(db chaindb.Batch, curr node) error { +func (t *Trie) store(db chaindb.Batch, curr Node) error { if curr == nil { return nil } - enc, hash, err := curr.encodeAndHash() + enc, hash, err := curr.EncodeAndHash() if err != nil { return err } @@ -46,8 +48,8 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { return err } - if c, ok := curr.(*branch); ok { - for _, child := range c.children { + if c, ok := curr.(*node.Branch); ok { + for _, child := range c.Children { if child == nil { continue } @@ -59,8 +61,8 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { } } - if curr.isDirty() { - curr.setDirty(false) + if curr.IsDirty() { + curr.SetDirty(false) } return nil @@ -72,20 +74,20 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return ErrEmptyProof } - mappedNodes := make(map[string]node, len(proof)) + mappedNodes := make(map[string]Node, len(proof)) // map all the proofs hash -> decoded node // and takes the loop to indentify the root node for _, rawNode := range proof { - decNode, err := decodeBytes(rawNode) + decNode, err := node.Decode(bytes.NewReader(rawNode)) if err != nil { return err } - decNode.setDirty(false) - decNode.setEncodingAndHash(rawNode, nil) + decNode.SetDirty(false) + decNode.SetEncodingAndHash(rawNode, nil) - _, computedRoot, err := decNode.encodeAndHash() + _, computedRoot, err := decNode.EncodeAndHash() if err != nil { return err } @@ -103,23 +105,23 @@ func (t *Trie) 
LoadFromProof(proof [][]byte, root []byte) error { // loadProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root -func (t *Trie) loadProof(proof map[string]node, curr node) { - c, ok := curr.(*branch) +func (t *Trie) loadProof(proof map[string]Node, curr Node) { + c, ok := curr.(*node.Branch) if !ok { return } - for i, child := range c.children { + for i, child := range c.Children { if child == nil { continue } - proofNode, ok := proof[common.BytesToHex(child.getHash())] + proofNode, ok := proof[common.BytesToHex(child.GetHash())] if !ok { continue } - c.children[i] = proofNode + c.Children[i] = proofNode t.loadProof(proof, proofNode) } } @@ -137,39 +139,39 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return fmt.Errorf("failed to find root key=%s: %w", root, err) } - t.root, err = decodeBytes(enc) + t.root, err = node.Decode(bytes.NewReader(enc)) if err != nil { return err } - t.root.setDirty(false) - t.root.setEncodingAndHash(enc, root[:]) + t.root.SetDirty(false) + t.root.SetEncodingAndHash(enc, root[:]) return t.load(db, t.root) } -func (t *Trie) load(db chaindb.Database, curr node) error { - if c, ok := curr.(*branch); ok { - for i, child := range c.children { +func (t *Trie) load(db chaindb.Database, curr Node) error { + if c, ok := curr.(*node.Branch); ok { + for i, child := range c.Children { if child == nil { continue } - hash := child.getHash() + hash := child.GetHash() enc, err := db.Get(hash) if err != nil { - return fmt.Errorf("failed to find node key=%x index=%d: %w", child.(*leaf).hash, i, err) + return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err) } - child, err = decodeBytes(enc) + child, err = node.Decode(bytes.NewReader(enc)) if err != nil { return err } - child.setDirty(false) - child.setEncodingAndHash(enc, hash) + child.SetDirty(false) + child.SetEncodingAndHash(enc, hash) - c.children[i] = child + c.Children[i] = child err = 
t.load(db, child) if err != nil { return err @@ -181,14 +183,14 @@ func (t *Trie) load(db chaindb.Database, curr node) error { } // GetNodeHashes return hash of each key of the trie. -func (t *Trie) GetNodeHashes(curr node, keys map[common.Hash]struct{}) error { - if c, ok := curr.(*branch); ok { - for _, child := range c.children { +func (t *Trie) GetNodeHashes(curr Node, keys map[common.Hash]struct{}) error { + if c, ok := curr.(*node.Branch); ok { + for _, child := range c.Children { if child == nil { continue } - hash := child.getHash() + hash := child.GetHash() keys[common.BytesToHash(hash)] = struct{}{} err := t.GetNodeHashes(child, keys) @@ -234,14 +236,14 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return nil, nil } - k := keyToNibbles(key) + k := codec.KeyLEToNibbles(key) enc, err := db.Get(root[:]) if err != nil { return nil, fmt.Errorf("failed to find root key=%s: %w", root, err) } - rootNode, err := decodeBytes(enc) + rootNode, err := node.Decode(bytes.NewReader(enc)) if err != nil { return nil, err } @@ -249,34 +251,34 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return getFromDB(db, rootNode, k) } -func getFromDB(db chaindb.Database, parent node, key []byte) ([]byte, error) { +func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { var value []byte switch p := parent.(type) { - case *branch: - length := lenCommonPrefix(p.key, key) + case *node.Branch: + length := lenCommonPrefix(p.Key, key) // found the value at this node - if bytes.Equal(p.key, key) || len(key) == 0 { - return p.value, nil + if bytes.Equal(p.Key, key) || len(key) == 0 { + return p.Value, nil } // did not find value - if bytes.Equal(p.key[:length], key) && len(key) < len(p.key) { + if bytes.Equal(p.Key[:length], key) && len(key) < len(p.Key) { return nil, nil } - if p.children[key[length]] == nil { + if p.Children[key[length]] == nil { return nil, nil } // load child with potential 
value - enc, err := db.Get(p.children[key[length]].(*leaf).hash) + enc, err := db.Get(p.Children[key[length]].GetHash()) if err != nil { return nil, fmt.Errorf("failed to find node in database: %w", err) } - child, err := decodeBytes(enc) + child, err := node.Decode(bytes.NewReader(enc)) if err != nil { return nil, err } @@ -285,9 +287,9 @@ func getFromDB(db chaindb.Database, parent node, key []byte) ([]byte, error) { if err != nil { return nil, err } - case *leaf: - if bytes.Equal(p.key, key) { - return p.value, nil + case *node.Leaf: + if bytes.Equal(p.Key, key) { + return p.Value, nil } case nil: return nil, nil @@ -308,12 +310,12 @@ func (t *Trie) WriteDirty(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) writeDirty(db chaindb.Batch, curr node) error { - if curr == nil || !curr.isDirty() { +func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error { + if curr == nil || !curr.IsDirty() { return nil } - enc, hash, err := curr.encodeAndHash() + enc, hash, err := curr.EncodeAndHash() if err != nil { return err } @@ -333,8 +335,8 @@ func (t *Trie) writeDirty(db chaindb.Batch, curr node) error { return err } - if c, ok := curr.(*branch); ok { - for _, child := range c.children { + if c, ok := curr.(*node.Branch); ok { + for _, child := range c.Children { if child == nil { continue } @@ -346,7 +348,7 @@ func (t *Trie) writeDirty(db chaindb.Batch, curr node) error { } } - curr.setDirty(false) + curr.SetDirty(false) return nil } @@ -356,13 +358,13 @@ func (t *Trie) GetInsertedNodeHashes() ([]common.Hash, error) { return t.getInsertedNodeHashes(t.root) } -func (t *Trie) getInsertedNodeHashes(curr node) ([]common.Hash, error) { +func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) { var nodeHashes []common.Hash - if curr == nil || !curr.isDirty() { + if curr == nil || !curr.IsDirty() { return nil, nil } - enc, hash, err := curr.encodeAndHash() + enc, hash, err := curr.EncodeAndHash() if err != nil { return nil, err } @@ -379,8 
+381,8 @@ func (t *Trie) getInsertedNodeHashes(curr node) ([]common.Hash, error) { nodeHash := common.BytesToHash(hash) nodeHashes = append(nodeHashes, nodeHash) - if c, ok := curr.(*branch); ok { - for _, child := range c.children { + if c, ok := curr.(*node.Branch); ok { + for _, child := range c.Children { if child == nil { continue } diff --git a/lib/trie/hash.go b/lib/trie/hash.go deleted file mode 100644 index ecb674ea82..0000000000 --- a/lib/trie/hash.go +++ /dev/null @@ -1,348 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "errors" - "fmt" - "hash" - "io" - "sync" - - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" - "golang.org/x/crypto/blake2b" -) - -var encodingBufferPool = &sync.Pool{ - New: func() interface{} { - const initialBufferCapacity = 1900000 // 1.9MB, from checking capacities at runtime - b := make([]byte, 0, initialBufferCapacity) - return bytes.NewBuffer(b) - }, -} - -var digestBufferPool = &sync.Pool{ - New: func() interface{} { - const bufferCapacity = 32 - b := make([]byte, 0, bufferCapacity) - return bytes.NewBuffer(b) - }, -} - -var hasherPool = &sync.Pool{ - New: func() interface{} { - hasher, err := blake2b.New256(nil) - if err != nil { - panic("cannot create Blake2b-256 hasher: " + err.Error()) - } - return hasher - }, -} - -func hashNode(n node, digestBuffer io.Writer) (err error) { - encodingBuffer := encodingBufferPool.Get().(*bytes.Buffer) - encodingBuffer.Reset() - defer encodingBufferPool.Put(encodingBuffer) - - const parallel = false - - err = encodeNode(n, encodingBuffer, parallel) - if err != nil { - return fmt.Errorf("cannot encode node: %w", err) - } - - // if length of encoded leaf is less than 32 bytes, do not hash - if encodingBuffer.Len() < 32 { - _, err = digestBuffer.Write(encodingBuffer.Bytes()) - if err != nil { - return fmt.Errorf("cannot write encoded node to buffer: %w", err) - } - return 
nil - } - - // otherwise, hash encoded node - hasher := hasherPool.Get().(hash.Hash) - hasher.Reset() - defer hasherPool.Put(hasher) - - // Note: using the sync.Pool's buffer is useful here. - _, err = hasher.Write(encodingBuffer.Bytes()) - if err != nil { - return fmt.Errorf("cannot hash encoded node: %w", err) - } - - _, err = digestBuffer.Write(hasher.Sum(nil)) - if err != nil { - return fmt.Errorf("cannot write hash sum of node to buffer: %w", err) - } - return nil -} - -var ErrNodeTypeUnsupported = errors.New("node type is not supported") - -type bytesBuffer interface { - // note: cannot compose with io.Writer for mock generation - Write(p []byte) (n int, err error) - Len() int - Bytes() []byte -} - -// encodeNode writes the encoding of the node to the buffer given. -// It is the high-level function wrapping the encoding for different -// node types. The encoding has the following format: -// NodeHeader | Extra partial key length | Partial Key | Value -func encodeNode(n node, buffer bytesBuffer, parallel bool) (err error) { - switch n := n.(type) { - case *branch: - err := encodeBranch(n, buffer, parallel) - if err != nil { - return fmt.Errorf("cannot encode branch: %w", err) - } - return nil - case *leaf: - err := encodeLeaf(n, buffer) - if err != nil { - return fmt.Errorf("cannot encode leaf: %w", err) - } - - n.encodingMu.Lock() - defer n.encodingMu.Unlock() - - // TODO remove this copying since it defeats the purpose of `buffer` - // and the sync.Pool. - n.encoding = make([]byte, buffer.Len()) - copy(n.encoding, buffer.Bytes()) - return nil - case nil: - _, err := buffer.Write([]byte{0}) - if err != nil { - return fmt.Errorf("cannot encode nil node: %w", err) - } - return nil - default: - return fmt.Errorf("%w: %T", ErrNodeTypeUnsupported, n) - } -} - -// encodeBranch encodes a branch with the encoding specified at the top of this package -// to the buffer given. 
-func encodeBranch(b *branch, buffer io.Writer, parallel bool) (err error) { - if !b.dirty && b.encoding != nil { - _, err = buffer.Write(b.encoding) - if err != nil { - return fmt.Errorf("cannot write stored encoding to buffer: %w", err) - } - return nil - } - - encodedHeader, err := b.header() - if err != nil { - return fmt.Errorf("cannot encode header: %w", err) - } - - _, err = buffer.Write(encodedHeader) - if err != nil { - return fmt.Errorf("cannot write encoded header to buffer: %w", err) - } - - keyLE := nibblesToKeyLE(b.key) - _, err = buffer.Write(keyLE) - if err != nil { - return fmt.Errorf("cannot write encoded key to buffer: %w", err) - } - - childrenBitmap := common.Uint16ToBytes(b.childrenBitmap()) - _, err = buffer.Write(childrenBitmap) - if err != nil { - return fmt.Errorf("cannot write children bitmap to buffer: %w", err) - } - - if b.value != nil { - bytes, err := scale.Marshal(b.value) - if err != nil { - return fmt.Errorf("cannot scale encode value: %w", err) - } - - _, err = buffer.Write(bytes) - if err != nil { - return fmt.Errorf("cannot write encoded value to buffer: %w", err) - } - } - - if parallel { - err = encodeChildrenInParallel(b.children, buffer) - } else { - err = encodeChildrenSequentially(b.children, buffer) - } - if err != nil { - return fmt.Errorf("cannot encode children of branch: %w", err) - } - - return nil -} - -func encodeChildrenInParallel(children [16]node, buffer io.Writer) (err error) { - type result struct { - index int - buffer *bytes.Buffer - err error - } - - resultsCh := make(chan result) - - for i, child := range children { - go func(index int, child node) { - buffer := encodingBufferPool.Get().(*bytes.Buffer) - buffer.Reset() - // buffer is put back in the pool after processing its - // data in the select block below. 
- - err := encodeChild(child, buffer) - - resultsCh <- result{ - index: index, - buffer: buffer, - err: err, - } - }(i, child) - } - - currentIndex := 0 - resultBuffers := make([]*bytes.Buffer, len(children)) - for range children { - result := <-resultsCh - if result.err != nil && err == nil { // only set the first error we get - err = result.err - } - - resultBuffers[result.index] = result.buffer - - // write as many completed buffers to the result buffer. - for currentIndex < len(children) && - resultBuffers[currentIndex] != nil { - bufferSlice := resultBuffers[currentIndex].Bytes() - if len(bufferSlice) > 0 { - // note buffer.Write copies the byte slice given as argument - _, writeErr := buffer.Write(bufferSlice) - if writeErr != nil && err == nil { - err = fmt.Errorf( - "cannot write encoding of child at index %d: %w", - currentIndex, writeErr) - } - } - - encodingBufferPool.Put(resultBuffers[currentIndex]) - resultBuffers[currentIndex] = nil - - currentIndex++ - } - } - - for _, buffer := range resultBuffers { - if buffer == nil { // already emptied and put back in pool - continue - } - encodingBufferPool.Put(buffer) - } - - return err -} - -func encodeChildrenSequentially(children [16]node, buffer io.Writer) (err error) { - for i, child := range children { - err = encodeChild(child, buffer) - if err != nil { - return fmt.Errorf("cannot encode child at index %d: %w", i, err) - } - } - return nil -} - -func encodeChild(child node, buffer io.Writer) (err error) { - var isNil bool - switch impl := child.(type) { - case *branch: - isNil = impl == nil - case *leaf: - isNil = impl == nil - default: - isNil = child == nil - } - if isNil { - return nil - } - - scaleEncodedChild, err := encodeAndHash(child) - if err != nil { - return fmt.Errorf("failed to hash and scale encode child: %w", err) - } - - _, err = buffer.Write(scaleEncodedChild) - if err != nil { - return fmt.Errorf("failed to write child to buffer: %w", err) - } - - return nil -} - -func encodeAndHash(n 
node) (b []byte, err error) { - buffer := digestBufferPool.Get().(*bytes.Buffer) - buffer.Reset() - defer digestBufferPool.Put(buffer) - - err = hashNode(n, buffer) - if err != nil { - return nil, fmt.Errorf("cannot hash node: %w", err) - } - - scEncChild, err := scale.Marshal(buffer.Bytes()) - if err != nil { - return nil, fmt.Errorf("cannot scale encode hashed node: %w", err) - } - return scEncChild, nil -} - -// encodeLeaf encodes a leaf to the buffer given, with the encoding -// specified at the top of this package. -func encodeLeaf(l *leaf, buffer io.Writer) (err error) { - l.encodingMu.RLock() - defer l.encodingMu.RUnlock() - if !l.dirty && l.encoding != nil { - _, err = buffer.Write(l.encoding) - if err != nil { - return fmt.Errorf("cannot write stored encoding to buffer: %w", err) - } - return nil - } - - encodedHeader, err := l.header() - if err != nil { - return fmt.Errorf("cannot encode header: %w", err) - } - - _, err = buffer.Write(encodedHeader) - if err != nil { - return fmt.Errorf("cannot write encoded header to buffer: %w", err) - } - - keyLE := nibblesToKeyLE(l.key) - _, err = buffer.Write(keyLE) - if err != nil { - return fmt.Errorf("cannot write LE key to buffer: %w", err) - } - - encodedValue, err := scale.Marshal(l.value) // TODO scale encoder to write to buffer - if err != nil { - return fmt.Errorf("cannot scale marshal value: %w", err) - } - - _, err = buffer.Write(encodedValue) - if err != nil { - return fmt.Errorf("cannot write scale encoded value to buffer: %w", err) - } - - return nil -} diff --git a/lib/trie/hash_test.go b/lib/trie/hash_test.go deleted file mode 100644 index 23265fdf57..0000000000 --- a/lib/trie/hash_test.go +++ /dev/null @@ -1,1012 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "errors" - "testing" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type writeCall struct { - 
written []byte - n int - err error -} - -var errTest = errors.New("test error") - -//go:generate mockgen -destination=bytesBuffer_mock_test.go -package $GOPACKAGE -source=hash.go . bytesBuffer -//go:generate mockgen -destination=node_mock_test.go -package $GOPACKAGE -source=node.go . node - -func Test_hashNode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - n node - writeCall bool - write writeCall - wrappedErr error - errMessage string - }{ - "node encoding error": { - n: NewMocknode(nil), - wrappedErr: ErrNodeTypeUnsupported, - errMessage: "cannot encode node: " + - "node type is not supported: " + - "*trie.Mocknode", - }, - "small leaf buffer write error": { - n: &leaf{ - encoding: []byte{1, 2, 3}, - }, - writeCall: true, - write: writeCall{ - written: []byte{1, 2, 3}, - err: errTest, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded node to buffer: " + - "test error", - }, - "small leaf success": { - n: &leaf{ - encoding: []byte{1, 2, 3}, - }, - writeCall: true, - write: writeCall{ - written: []byte{1, 2, 3}, - }, - }, - "leaf hash sum buffer write error": { - n: &leaf{ - encoding: []byte{ - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - }, - writeCall: true, - write: writeCall{ - written: []byte{ - 107, 105, 154, 175, 253, 170, 232, - 135, 240, 21, 207, 148, 82, 117, - 249, 230, 80, 197, 254, 17, 149, - 108, 50, 7, 80, 56, 114, 176, - 84, 114, 125, 234}, - err: errTest, - }, - wrappedErr: errTest, - errMessage: "cannot write hash sum of node to buffer: " + - "test error", - }, - "leaf hash sum success": { - n: &leaf{ - encoding: []byte{ - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - }, - writeCall: true, - write: writeCall{ - written: []byte{ - 107, 105, 154, 175, 253, 170, 232, - 135, 240, 21, 207, 148, 82, 117, - 249, 230, 80, 197, 254, 17, 149, - 
108, 50, 7, 80, 56, 114, 176, - 84, 114, 125, 234}, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockWriter(ctrl) - if testCase.writeCall { - buffer.EXPECT(). - Write(testCase.write.written). - Return(testCase.write.n, testCase.write.err) - } - - err := hashNode(testCase.n, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeNode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - n node - writes []writeCall - leafEncodingCopy bool - leafBufferLen int - leafBufferBytes []byte - parallel bool - wrappedErr error - errMessage string - }{ - "branch error": { - n: &branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}, err: errTest}, - }, - wrappedErr: errTest, - errMessage: "cannot encode branch: " + - "cannot write stored encoding to buffer: " + - "test error", - }, - "branch success": { - n: &branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}}, - }, - }, - "leaf error": { - n: &leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}, err: errTest}, - }, - wrappedErr: errTest, - errMessage: "cannot encode leaf: " + - "cannot write stored encoding to buffer: " + - "test error", - }, - "leaf success": { - n: &leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}}, - }, - leafEncodingCopy: true, - leafBufferLen: 3, - leafBufferBytes: []byte{1, 2, 3}, - }, - "nil node error": { - writes: []writeCall{ - {written: []byte{0}, err: errTest}, - }, - wrappedErr: errTest, - errMessage: "cannot encode nil node: test error", - }, - "nil node success": { - writes: []writeCall{ - {written: []byte{0}}, - }, - }, - 
"unsupported node type": { - n: NewMocknode(nil), - wrappedErr: ErrNodeTypeUnsupported, - errMessage: "node type is not supported: *trie.Mocknode", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockbytesBuffer(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). - Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - if testCase.leafEncodingCopy { - previousCall = buffer.EXPECT().Len(). - Return(testCase.leafBufferLen). - After(previousCall) - buffer.EXPECT().Bytes(). - Return(testCase.leafBufferBytes). - After(previousCall) - } - - err := encodeNode(testCase.n, buffer, testCase.parallel) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeBranch(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - branch *branch - writes []writeCall - parallel bool - wrappedErr error - errMessage string - }{ - "clean branch with encoding": { - branch: &branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { // stored encoding - written: []byte{1, 2, 3}, - }, - }, - }, - "write error for clean branch with encoding": { - branch: &branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { // stored encoding - written: []byte{1, 2, 3}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write stored encoding to buffer: test error", - }, - "header encoding error": { - branch: &branch{ - key: make([]byte, 63+(1<<16)), - }, - wrappedErr: ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length greater than or equal to 2^16", - }, - "buffer write error for encoded header": { - branch: &branch{ - 
key: []byte{1, 2, 3}, - value: []byte{100}, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded header to buffer: test error", - }, - "buffer write error for encoded key": { - branch: &branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded key to buffer: test error", - }, - "buffer write error for children bitmap": { - branch: &branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write children bitmap to buffer: test error", - }, - "buffer write error for value": { - branch: &branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded value to buffer: test error", - }, - "buffer write error for children encoded sequentially": { - branch: &branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: 
[]byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // children - written: []byte{12, 65, 9, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot encode children of branch: " + - "cannot encode child at index 3: " + - "failed to write child to buffer: test error", - }, - "buffer write error for children encoded in parallel": { - branch: &branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - err: errTest, - }, - { // second children - written: []byte{12, 65, 11, 0}, - }, - }, - parallel: true, - wrappedErr: errTest, - errMessage: "cannot encode children of branch: " + - "cannot write encoding of child at index 3: " + - "test error", - }, - "success with parallel children encoding": { - branch: &branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - }, - { // second children - written: []byte{12, 65, 11, 0}, - }, - }, - parallel: true, - }, - "success with sequential children encoding": { - branch: &branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // 
children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - }, - { // second children - written: []byte{12, 65, 11, 0}, - }, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). - Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeBranch(testCase.branch, buffer, testCase.parallel) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -//go:generate mockgen -destination=readwriter_mock_test.go -package $GOPACKAGE io ReadWriter - -func Test_encodeChildrenInParallel(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - children [16]node - writes []writeCall - wrappedErr error - errMessage string - }{ - "no children": {}, - "first child not nil": { - children: [16]node{ - &leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "last child not nil": { - children: [16]node{ - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - &leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "first two children not nil": { - children: [16]node{ - &leaf{key: []byte{1}}, - &leaf{key: []byte{2}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - { - written: []byte{12, 65, 2, 0}, - }, - }, - }, - "encoding error": { - children: [16]node{ - nil, nil, nil, nil, - nil, nil, nil, nil, - nil, nil, nil, - &leaf{ - key: []byte{1}, - }, - nil, nil, nil, nil, - }, - 
writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoding of child at index 11: " + - "test error", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). - Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeChildrenInParallel(testCase.children, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeChildrenSequentially(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - children [16]node - writes []writeCall - wrappedErr error - errMessage string - }{ - "no children": {}, - "first child not nil": { - children: [16]node{ - &leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "last child not nil": { - children: [16]node{ - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - &leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "first two children not nil": { - children: [16]node{ - &leaf{key: []byte{1}}, - &leaf{key: []byte{2}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - { - written: []byte{12, 65, 2, 0}, - }, - }, - }, - "encoding error": { - children: [16]node{ - nil, nil, nil, nil, - nil, nil, nil, nil, - nil, nil, nil, - &leaf{ - key: []byte{1}, - }, - nil, nil, nil, nil, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot 
encode child at index 11: " + - "failed to write child to buffer: test error", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). - Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeChildrenSequentially(testCase.children, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer - -func Test_encodeChild(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - child node - writeCall bool - write writeCall - wrappedErr error - errMessage string - }{ - "nil node": {}, - "nil leaf": { - child: (*leaf)(nil), - }, - "nil branch": { - child: (*branch)(nil), - }, - "empty leaf child": { - child: &leaf{}, - writeCall: true, - write: writeCall{ - written: []byte{8, 64, 0}, - }, - }, - "empty branch child": { - child: &branch{}, - writeCall: true, - write: writeCall{ - written: []byte{12, 128, 0, 0}, - }, - }, - "buffer write error": { - child: &branch{}, - writeCall: true, - write: writeCall{ - written: []byte{12, 128, 0, 0}, - err: errTest, - }, - wrappedErr: errTest, - errMessage: "failed to write child to buffer: test error", - }, - "leaf child": { - child: &leaf{ - key: []byte{1}, - value: []byte{2}, - }, - writeCall: true, - write: writeCall{ - written: []byte{16, 65, 1, 4, 2}, - }, - }, - "branch child": { - child: &branch{ - key: []byte{1}, - value: []byte{2}, - children: [16]node{ - nil, nil, &leaf{ - key: []byte{5}, - value: []byte{6}, - }, - }, - }, - writeCall: true, - write: writeCall{ - 
written: []byte{44, 193, 1, 4, 0, 4, 2, 16, 65, 5, 4, 6}, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockWriter(ctrl) - - if testCase.writeCall { - buffer.EXPECT(). - Write(testCase.write.written). - Return(testCase.write.n, testCase.write.err) - } - - err := encodeChild(testCase.child, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeAndHash(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - n node - b []byte - wrappedErr error - errMessage string - }{ - "node encoding error": { - n: NewMocknode(nil), - wrappedErr: ErrNodeTypeUnsupported, - errMessage: "cannot hash node: " + - "cannot encode node: " + - "node type is not supported: " + - "*trie.Mocknode", - }, - "leaf": { - n: &leaf{}, - b: []byte{0x8, 0x40, 0}, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - b, err := encodeAndHash(testCase.n) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - - assert.Equal(t, testCase.b, b) - }) - } -} - -func Test_encodeLeaf(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - leaf *leaf - writes []writeCall - wrappedErr error - errMessage string - }{ - "clean leaf with encoding": { - leaf: &leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{1, 2, 3}, - }, - }, - }, - "write error for clean leaf with encoding": { - leaf: &leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{1, 2, 3}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write stored encoding to 
buffer: test error", - }, - "header encoding error": { - leaf: &leaf{ - key: make([]byte, 63+(1<<16)), - }, - wrappedErr: ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length greater than or equal to 2^16", - }, - "buffer write error for encoded header": { - leaf: &leaf{ - key: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{67}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded header to buffer: test error", - }, - "buffer write error for encoded key": { - leaf: &leaf{ - key: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{67}, - }, - { - written: []byte{1, 35}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write LE key to buffer: test error", - }, - "buffer write error for encoded value": { - leaf: &leaf{ - key: []byte{1, 2, 3}, - value: []byte{4, 5, 6}, - }, - writes: []writeCall{ - { - written: []byte{67}, - }, - { - written: []byte{1, 35}, - }, - { - written: []byte{12, 4, 5, 6}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write scale encoded value to buffer: test error", - }, - "success": { - leaf: &leaf{ - key: []byte{1, 2, 3}, - value: []byte{4, 5, 6}, - }, - writes: []writeCall{ - { - written: []byte{67}, - }, - { - written: []byte{1, 35}, - }, - { - written: []byte{12, 4, 5, 6}, - }, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). 
- Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeLeaf(testCase.leaf, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index c15501a2d3..abf9ee9192 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -5,37 +5,46 @@ package trie import ( "bytes" + + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/record" ) +var _ recorder = (*record.Recorder)(nil) + +type recorder interface { + Record(hash, rawData []byte) +} + // findAndRecord search for a desired key recording all the nodes in the path including the desired node -func findAndRecord(t *Trie, key []byte, recorder *recorder) error { +func findAndRecord(t *Trie, key []byte, recorder recorder) error { return find(t.root, key, recorder) } -func find(parent node, key []byte, recorder *recorder) error { - enc, hash, err := parent.encodeAndHash() +func find(parent Node, key []byte, recorder recorder) error { + enc, hash, err := parent.EncodeAndHash() if err != nil { return err } - recorder.record(hash, enc) + recorder.Record(hash, enc) - b, ok := parent.(*branch) + b, ok := parent.(*node.Branch) if !ok { return nil } - length := lenCommonPrefix(b.key, key) + length := lenCommonPrefix(b.Key, key) // found the value at this node - if bytes.Equal(b.key, key) || len(key) == 0 { + if bytes.Equal(b.Key, key) || len(key) == 0 { return nil } // did not find value - if bytes.Equal(b.key[:length], key) && len(key) < len(b.key) { + if bytes.Equal(b.Key[:length], key) && len(key) < len(b.Key) { return nil } - return find(b.children[key[length]], key[length+1:], recorder) + return find(b.Children[key[length]], key[length+1:], recorder) } diff --git a/lib/trie/node.go b/lib/trie/node.go index 
1def45a4c5..8ab60a5455 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -1,538 +1,9 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -//nolint:lll -// Modified Merkle-Patricia Trie -// See https://github.com/w3f/polkadot-spec/blob/master/runtime-environment-spec/polkadot_re_spec.pdf for the full specification. -// -// Note that for the following definitions, `|` denotes concatenation -// -// Branch encoding: -// NodeHeader | Extra partial key length | Partial Key | Value -// `NodeHeader` is a byte such that: -// most significant two bits of `NodeHeader`: 10 if branch w/o value, 11 if branch w/ value -// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) -// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length -// `Partial Key` is the branch's key -// `Value` is: Children Bitmap | SCALE Branch node Value | Hash(Enc(Child[i_1])) | Hash(Enc(Child[i_2])) | ... | Hash(Enc(Child[i_n])) -// -// Leaf encoding: -// NodeHeader | Extra partial key length | Partial Key | Value -// `NodeHeader` is a byte such that: -// most significant two bits of `NodeHeader`: 01 -// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) -// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length -// `Partial Key` is the leaf's key -// `Value` is the leaf's SCALE encoded value - package trie -import ( - "bytes" - "errors" - "fmt" - "io" - "sync" - - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" -) - -// node is the interface for trie methods -type node interface { - encodeAndHash() ([]byte, []byte, error) - decode(r io.Reader, h byte) error - isDirty() bool - setDirty(dirty bool) - setKey(key []byte) - String() string - setEncodingAndHash([]byte, []byte) - getHash() []byte - getGeneration() uint64 - setGeneration(uint64) - copy() node -} - -type ( - branch 
struct { - key []byte // partial key - children [16]node - value []byte - dirty bool - hash []byte - encoding []byte - generation uint64 - sync.RWMutex - } - leaf struct { - key []byte // partial key - value []byte - dirty bool - hash []byte - encoding []byte - encodingMu sync.RWMutex - generation uint64 - sync.RWMutex - } -) - -func (b *branch) setGeneration(generation uint64) { - b.generation = generation -} - -func (l *leaf) setGeneration(generation uint64) { - l.generation = generation -} - -func (b *branch) copy() node { - b.RLock() - defer b.RUnlock() - - cpy := &branch{ - key: make([]byte, len(b.key)), - children: b.children, // copy interface pointers - value: nil, - dirty: b.dirty, - hash: make([]byte, len(b.hash)), - encoding: make([]byte, len(b.encoding)), - generation: b.generation, - } - copy(cpy.key, b.key) - - // nil and []byte{} are encoded differently, watch out! - if b.value != nil { - cpy.value = make([]byte, len(b.value)) - copy(cpy.value, b.value) - } - - copy(cpy.hash, b.hash) - copy(cpy.encoding, b.encoding) - return cpy -} - -func (l *leaf) copy() node { - l.RLock() - defer l.RUnlock() - - l.encodingMu.RLock() - defer l.encodingMu.RUnlock() - - cpy := &leaf{ - key: make([]byte, len(l.key)), - value: make([]byte, len(l.value)), - dirty: l.dirty, - hash: make([]byte, len(l.hash)), - encoding: make([]byte, len(l.encoding)), - generation: l.generation, - } - copy(cpy.key, l.key) - copy(cpy.value, l.value) - copy(cpy.hash, l.hash) - copy(cpy.encoding, l.encoding) - return cpy -} - -func (b *branch) setEncodingAndHash(enc, hash []byte) { - b.encoding = enc - b.hash = hash -} - -func (l *leaf) setEncodingAndHash(enc, hash []byte) { - l.encodingMu.Lock() - l.encoding = enc - l.encodingMu.Unlock() - - l.hash = hash -} - -func (b *branch) getHash() []byte { - return b.hash -} - -func (b *branch) getGeneration() uint64 { - return b.generation -} - -func (l *leaf) getGeneration() uint64 { - return l.generation -} - -func (l *leaf) getHash() []byte { - 
return l.hash -} - -func (b *branch) String() string { - if len(b.value) > 1024 { - return fmt.Sprintf( - "branch key=%x childrenBitmap=%16b value (hashed)=%x dirty=%v", - b.key, b.childrenBitmap(), common.MustBlake2bHash(b.value), b.dirty) - } - return fmt.Sprintf("branch key=%x childrenBitmap=%16b value=%v dirty=%v", b.key, b.childrenBitmap(), b.value, b.dirty) -} - -func (l *leaf) String() string { - if len(l.value) > 1024 { - return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.key, common.MustBlake2bHash(l.value), l.dirty) - } - return fmt.Sprintf("leaf key=%x value=%v dirty=%v", l.key, l.value, l.dirty) -} - -func (b *branch) childrenBitmap() uint16 { - var bitmap uint16 - var i uint - for i = 0; i < 16; i++ { - if b.children[i] != nil { - bitmap = bitmap | 1<> 6 - if nodeType == 1 { - l := new(leaf) - err := l.decode(r, header) - return l, err - } else if nodeType == 2 || nodeType == 3 { - b := new(branch) - err := b.decode(r, header) - return b, err - } - - return nil, errors.New("cannot decode invalid encoding into node") -} - -// Decode decodes a byte array with the encoding specified at the top of this package into a branch node -// Note that since the encoded branch stores the hash of the children nodes, we aren't able to reconstruct the child -// nodes from the encoding. This function instead stubs where the children are known to be with an empty leaf. 
-func (b *branch) decode(r io.Reader, header byte) (err error) { - if header == 0 { - header, err = readByte(r) - if err != nil { - return err - } - } - - nodeType := header >> 6 - if nodeType != 2 && nodeType != 3 { - return fmt.Errorf("cannot decode node to branch") - } - - keyLen := header & 0x3f - b.key, err = decodeKey(r, keyLen) - if err != nil { - return err - } - - childrenBitmap := make([]byte, 2) - _, err = r.Read(childrenBitmap) - if err != nil { - return err - } - - sd := scale.NewDecoder(r) - - if nodeType == 3 { - var value []byte - // branch w/ value - err := sd.Decode(&value) - if err != nil { - return err - } - b.value = value - } - - for i := 0; i < 16; i++ { - if (childrenBitmap[i/8]>>(i%8))&1 == 1 { - var hash []byte - err := sd.Decode(&hash) - if err != nil { - return err - } - - b.children[i] = &leaf{ - hash: hash, - } - } - } - - b.dirty = true - - return nil -} - -// Decode decodes a byte array with the encoding specified at the top of this package into a leaf node -func (l *leaf) decode(r io.Reader, header byte) (err error) { - if header == 0 { - header, err = readByte(r) - if err != nil { - return err - } - } - - nodeType := header >> 6 - if nodeType != 1 { - return fmt.Errorf("cannot decode node to leaf") - } - - keyLen := header & 0x3f - l.key, err = decodeKey(r, keyLen) - if err != nil { - return err - } - - sd := scale.NewDecoder(r) - var value []byte - err = sd.Decode(&value) - if err != nil { - return err - } - - if len(value) > 0 { - l.value = value - } - - l.dirty = true - - return nil -} - -func (b *branch) header() ([]byte, error) { - var header byte - if b.value == nil { - header = 2 << 6 - } else { - header = 3 << 6 - } - var encodePkLen []byte - var err error - - if len(b.key) >= 63 { - header = header | 0x3f - encodePkLen, err = encodeExtraPartialKeyLength(len(b.key)) - if err != nil { - return nil, err - } - } else { - header = header | byte(len(b.key)) - } - - fullHeader := append([]byte{header}, encodePkLen...) 
- return fullHeader, nil -} - -func (l *leaf) header() ([]byte, error) { - var header byte = 1 << 6 - var encodePkLen []byte - var err error - - if len(l.key) >= 63 { - header = header | 0x3f - encodePkLen, err = encodeExtraPartialKeyLength(len(l.key)) - if err != nil { - return nil, err - } - } else { - header = header | byte(len(l.key)) - } - - fullHeader := append([]byte{header}, encodePkLen...) - return fullHeader, nil -} - -var ErrPartialKeyTooBig = errors.New("partial key length greater than or equal to 2^16") - -func encodeExtraPartialKeyLength(pkLen int) ([]byte, error) { - pkLen -= 63 - fullHeader := []byte{} - - if pkLen >= 1<<16 { - return nil, ErrPartialKeyTooBig - } - - for i := 0; i < 1<<16; i++ { - if pkLen < 255 { - fullHeader = append(fullHeader, byte(pkLen)) - break - } else { - fullHeader = append(fullHeader, byte(255)) - pkLen -= 255 - } - } - - return fullHeader, nil -} - -func decodeKey(r io.Reader, keyLen byte) ([]byte, error) { - var totalKeyLen = int(keyLen) - - if keyLen == 0x3f { - // partial key longer than 63, read next bytes for rest of pk len - for { - nextKeyLen, err := readByte(r) - if err != nil { - return nil, err - } - totalKeyLen += int(nextKeyLen) - - if nextKeyLen < 0xff { - break - } - - if totalKeyLen >= 1<<16 { - return nil, errors.New("partial key length greater than or equal to 2^16") - } - } - } - - if totalKeyLen != 0 { - key := make([]byte, totalKeyLen/2+totalKeyLen%2) - _, err := r.Read(key) - if err != nil { - return key, err - } - - return keyToNibbles(key)[totalKeyLen%2:], nil - } - - return []byte{}, nil -} +import "github.com/ChainSafe/gossamer/internal/trie/node" -func readByte(r io.Reader) (byte, error) { - buf := make([]byte, 1) - _, err := r.Read(buf) - if err != nil { - return 0, err - } - return buf[0], nil -} +// Node is a node in the trie and can be a leaf or a branch. 
+type Node node.Node diff --git a/lib/trie/node_mock_test.go b/lib/trie/node_mock_test.go deleted file mode 100644 index 0235377728..0000000000 --- a/lib/trie/node_mock_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: node.go - -// Package trie is a generated GoMock package. -package trie - -import ( - io "io" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// Mocknode is a mock of node interface. -type Mocknode struct { - ctrl *gomock.Controller - recorder *MocknodeMockRecorder -} - -// MocknodeMockRecorder is the mock recorder for Mocknode. -type MocknodeMockRecorder struct { - mock *Mocknode -} - -// NewMocknode creates a new mock instance. -func NewMocknode(ctrl *gomock.Controller) *Mocknode { - mock := &Mocknode{ctrl: ctrl} - mock.recorder = &MocknodeMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *Mocknode) EXPECT() *MocknodeMockRecorder { - return m.recorder -} - -// String mocks base method. -func (m *Mocknode) String() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "String") - ret0, _ := ret[0].(string) - return ret0 -} - -// String indicates an expected call of String. -func (mr *MocknodeMockRecorder) String() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*Mocknode)(nil).String)) -} - -// copy mocks base method. -func (m *Mocknode) copy() node { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "copy") - ret0, _ := ret[0].(node) - return ret0 -} - -// copy indicates an expected call of copy. -func (mr *MocknodeMockRecorder) copy() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "copy", reflect.TypeOf((*Mocknode)(nil).copy)) -} - -// decode mocks base method. 
-func (m *Mocknode) decode(r io.Reader, h byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "decode", r, h) - ret0, _ := ret[0].(error) - return ret0 -} - -// decode indicates an expected call of decode. -func (mr *MocknodeMockRecorder) decode(r, h interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "decode", reflect.TypeOf((*Mocknode)(nil).decode), r, h) -} - -// encodeAndHash mocks base method. -func (m *Mocknode) encodeAndHash() ([]byte, []byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "encodeAndHash") - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].([]byte) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// encodeAndHash indicates an expected call of encodeAndHash. -func (mr *MocknodeMockRecorder) encodeAndHash() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "encodeAndHash", reflect.TypeOf((*Mocknode)(nil).encodeAndHash)) -} - -// getGeneration mocks base method. -func (m *Mocknode) getGeneration() uint64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getGeneration") - ret0, _ := ret[0].(uint64) - return ret0 -} - -// getGeneration indicates an expected call of getGeneration. -func (mr *MocknodeMockRecorder) getGeneration() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getGeneration", reflect.TypeOf((*Mocknode)(nil).getGeneration)) -} - -// getHash mocks base method. -func (m *Mocknode) getHash() []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getHash") - ret0, _ := ret[0].([]byte) - return ret0 -} - -// getHash indicates an expected call of getHash. -func (mr *MocknodeMockRecorder) getHash() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getHash", reflect.TypeOf((*Mocknode)(nil).getHash)) -} - -// isDirty mocks base method. 
-func (m *Mocknode) isDirty() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "isDirty") - ret0, _ := ret[0].(bool) - return ret0 -} - -// isDirty indicates an expected call of isDirty. -func (mr *MocknodeMockRecorder) isDirty() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isDirty", reflect.TypeOf((*Mocknode)(nil).isDirty)) -} - -// setDirty mocks base method. -func (m *Mocknode) setDirty(dirty bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "setDirty", dirty) -} - -// setDirty indicates an expected call of setDirty. -func (mr *MocknodeMockRecorder) setDirty(dirty interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setDirty", reflect.TypeOf((*Mocknode)(nil).setDirty), dirty) -} - -// setEncodingAndHash mocks base method. -func (m *Mocknode) setEncodingAndHash(arg0, arg1 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "setEncodingAndHash", arg0, arg1) -} - -// setEncodingAndHash indicates an expected call of setEncodingAndHash. -func (mr *MocknodeMockRecorder) setEncodingAndHash(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setEncodingAndHash", reflect.TypeOf((*Mocknode)(nil).setEncodingAndHash), arg0, arg1) -} - -// setGeneration mocks base method. -func (m *Mocknode) setGeneration(arg0 uint64) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "setGeneration", arg0) -} - -// setGeneration indicates an expected call of setGeneration. -func (mr *MocknodeMockRecorder) setGeneration(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setGeneration", reflect.TypeOf((*Mocknode)(nil).setGeneration), arg0) -} - -// setKey mocks base method. -func (m *Mocknode) setKey(key []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "setKey", key) -} - -// setKey indicates an expected call of setKey. 
-func (mr *MocknodeMockRecorder) setKey(key interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setKey", reflect.TypeOf((*Mocknode)(nil).setKey), key) -} diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index 04e69a796d..f667bb82b4 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -5,213 +5,11 @@ package trie import ( "bytes" - "math/rand" - "strconv" "testing" - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" - - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// byteArray makes byte array with length specified; used to test byte array encoding -func byteArray(length int) []byte { - b := make([]byte, length) - for i := 0; i < length; i++ { - b[i] = 0xf - } - return b -} - -func generateRand(size int) [][]byte { - rt := make([][]byte, size) - for i := range rt { - buf := make([]byte, rand.Intn(379)+1) - rand.Read(buf) - rt[i] = buf - } - return rt -} - -func TestChildrenBitmap(t *testing.T) { - b := &branch{children: [16]node{}} - res := b.childrenBitmap() - if res != 0 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) - } - - b.children[0] = &leaf{key: []byte{0x00}, value: []byte{0x00}} - res = b.childrenBitmap() - if res != 1 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) - } - - b.children[4] = &leaf{key: []byte{0x00}, value: []byte{0x00}} - res = b.childrenBitmap() - if res != 1<<4+1 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 17) - } - - b.children[15] = &leaf{key: []byte{0x00}, value: []byte{0x00}} - res = b.childrenBitmap() - if res != 1<<15+1<<4+1 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 257) - } -} - -func TestBranchHeader(t *testing.T) { - tests := []struct { - br *branch - header []byte - }{ - {&branch{key: nil, children: [16]node{}, value: nil}, []byte{0x80}}, - {&branch{key: []byte{0x00}, children: 
[16]node{}, value: nil}, []byte{0x81}}, - {&branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]node{}, value: nil}, []byte{0x84}}, - - {&branch{key: nil, children: [16]node{}, value: []byte{0x01}}, []byte{0xc0}}, - {&branch{key: []byte{0x00}, children: [16]node{}, value: []byte{0x01}}, []byte{0xc1}}, - {&branch{key: []byte{0x00, 0x00}, children: [16]node{}, value: []byte{0x01}}, []byte{0xc2}}, - {&branch{key: []byte{0x00, 0x00, 0xf}, children: [16]node{}, value: []byte{0x01}}, []byte{0xc3}}, - - {&branch{key: byteArray(62), children: [16]node{}, value: nil}, []byte{0xbe}}, - {&branch{key: byteArray(62), children: [16]node{}, value: []byte{0x00}}, []byte{0xfe}}, - {&branch{key: byteArray(63), children: [16]node{}, value: nil}, []byte{0xbf, 0}}, - {&branch{key: byteArray(64), children: [16]node{}, value: nil}, []byte{0xbf, 1}}, - {&branch{key: byteArray(64), children: [16]node{}, value: []byte{0x01}}, []byte{0xff, 1}}, - - {&branch{key: byteArray(317), children: [16]node{}, value: []byte{0x01}}, []byte{255, 254}}, - {&branch{key: byteArray(318), children: [16]node{}, value: []byte{0x01}}, []byte{255, 255, 0}}, - {&branch{key: byteArray(573), children: [16]node{}, value: []byte{0x01}}, []byte{255, 255, 255, 0}}, - } - - for _, test := range tests { - test := test - res, err := test.br.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } else if !bytes.Equal(res, test.header) { - t.Errorf("Branch header fail case %v: got %x expected %x", test.br, res, test.header) - } - } -} - -func TestFailingPk(t *testing.T) { - tests := []struct { - br *branch - header []byte - }{ - {&branch{key: byteArray(2 << 16), children: [16]node{}, value: []byte{0x01}}, []byte{255, 254}}, - } - - for _, test := range tests { - _, err := test.br.header() - if err == nil { - t.Fatalf("should error when encoding node w pk length > 2^16") - } - } -} - -func TestLeafHeader(t *testing.T) { - tests := []struct { - br *leaf - header []byte - }{ - {&leaf{key: nil, 
value: nil}, []byte{0x40}}, - {&leaf{key: []byte{0x00}, value: nil}, []byte{0x41}}, - {&leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, []byte{0x44}}, - {&leaf{key: byteArray(62), value: nil}, []byte{0x7e}}, - {&leaf{key: byteArray(63), value: nil}, []byte{0x7f, 0}}, - {&leaf{key: byteArray(64), value: []byte{0x01}}, []byte{0x7f, 1}}, - - {&leaf{key: byteArray(318), value: []byte{0x01}}, []byte{0x7f, 0xff, 0}}, - {&leaf{key: byteArray(573), value: []byte{0x01}}, []byte{0x7f, 0xff, 0xff, 0}}, - } - - for i, test := range tests { - test := test - t.Run(strconv.Itoa(i), func(t *testing.T) { - res, err := test.br.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } else if !bytes.Equal(res, test.header) { - t.Errorf("Leaf header fail: got %x expected %x", res, test.header) - } - }) - } -} - -func TestBranchEncode(t *testing.T) { - randKeys := generateRand(101) - randVals := generateRand(101) - - for i, testKey := range randKeys { - b := &branch{key: testKey, children: [16]node{}, value: randVals[i]} - expected := bytes.NewBuffer(nil) - - header, err := b.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } - - expected.Write(header) - expected.Write(nibblesToKeyLE(b.key)) - expected.Write(common.Uint16ToBytes(b.childrenBitmap())) - - enc, err := scale.Marshal(b.value) - if err != nil { - t.Fatalf("Fail when encoding value with scale: %s", err) - } - - expected.Write(enc) - - for _, child := range b.children { - if child == nil { - continue - } - - err := hashNode(child, expected) - require.NoError(t, err) - } - - buffer := bytes.NewBuffer(nil) - const parallel = false - err = encodeBranch(b, buffer, parallel) - require.NoError(t, err) - assert.Equal(t, expected.Bytes(), buffer.Bytes()) - } -} - -func TestLeafEncode(t *testing.T) { - randKeys := generateRand(100) - randVals := generateRand(100) - - for i, testKey := range randKeys { - l := &leaf{key: testKey, value: randVals[i]} - expected := []byte{} - - 
header, err := l.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } - expected = append(expected, header...) - expected = append(expected, nibblesToKeyLE(l.key)...) - - enc, err := scale.Marshal(l.value) - if err != nil { - t.Fatalf("Fail when encoding value with scale: %s", err) - } - - expected = append(expected, enc...) - - buffer := bytes.NewBuffer(nil) - err = encodeLeaf(l, buffer) - require.NoError(t, err) - assert.Equal(t, expected, buffer.Bytes()) - } -} - func TestEncodeRoot(t *testing.T) { trie := NewEmptyTrie() @@ -222,133 +20,12 @@ func TestEncodeRoot(t *testing.T) { val := trie.Get(test.key) if !bytes.Equal(val, test.value) { - t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, val) + t.Errorf("Fail to get Key %x with value %x: got %x", test.Key(), test.value, val) } buffer := bytes.NewBuffer(nil) - const parallel = false - err := encodeNode(trie.root, buffer, parallel) + err := trie.root.Encode(buffer) require.NoError(t, err) } } } - -func TestBranchDecode(t *testing.T) { - tests := []*branch{ - {key: []byte{}, children: [16]node{}, value: nil}, - {key: []byte{0x00}, children: [16]node{}, value: nil}, - {key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]node{}, value: nil}, - {key: []byte{}, children: [16]node{}, value: []byte{0x01}}, - {key: []byte{}, children: [16]node{&leaf{}}, value: []byte{0x01}}, - {key: []byte{}, children: [16]node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, - { - key: []byte{}, - children: [16]node{ - &leaf{}, nil, &leaf{}, nil, - nil, nil, nil, nil, - nil, &leaf{}, nil, &leaf{}, - }, - value: []byte{0x01}, - }, - {key: byteArray(62), children: [16]node{}, value: nil}, - {key: byteArray(63), children: [16]node{}, value: nil}, - {key: byteArray(64), children: [16]node{}, value: nil}, - {key: byteArray(317), children: [16]node{}, value: []byte{0x01}}, - {key: byteArray(318), children: [16]node{}, value: []byte{0x01}}, - {key: byteArray(573), children: [16]node{}, value: 
[]byte{0x01}}, - } - - buffer := bytes.NewBuffer(nil) - const parallel = false - - for _, test := range tests { - err := encodeBranch(test, buffer, parallel) - require.NoError(t, err) - - res := new(branch) - err = res.decode(buffer, 0) - - require.NoError(t, err) - require.Equal(t, test.key, res.key) - require.Equal(t, test.childrenBitmap(), res.childrenBitmap()) - require.Equal(t, test.value, res.value) - } -} - -func TestLeafDecode(t *testing.T) { - tests := []*leaf{ - {key: []byte{}, value: nil, dirty: true}, - {key: []byte{0x01}, value: nil, dirty: true}, - {key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil, dirty: true}, - {key: byteArray(62), value: nil, dirty: true}, - {key: byteArray(63), value: nil, dirty: true}, - {key: byteArray(64), value: []byte{0x01}, dirty: true}, - {key: byteArray(318), value: []byte{0x01}, dirty: true}, - {key: byteArray(573), value: []byte{0x01}, dirty: true}, - } - - buffer := bytes.NewBuffer(nil) - - for _, test := range tests { - err := encodeLeaf(test, buffer) - require.NoError(t, err) - - res := new(leaf) - err = res.decode(buffer, 0) - require.NoError(t, err) - - res.hash = nil - test.encoding = nil - require.Equal(t, test, res) - } -} - -func TestDecode(t *testing.T) { - tests := []node{ - &branch{key: []byte{}, children: [16]node{}, value: nil}, - &branch{key: []byte{0x00}, children: [16]node{}, value: nil}, - &branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]node{}, value: nil}, - &branch{key: []byte{}, children: [16]node{}, value: []byte{0x01}}, - &branch{key: []byte{}, children: [16]node{&leaf{}}, value: []byte{0x01}}, - &branch{key: []byte{}, children: [16]node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, - &branch{ - key: []byte{}, - children: [16]node{ - &leaf{}, nil, &leaf{}, nil, - nil, nil, nil, nil, - nil, &leaf{}, nil, &leaf{}}, - value: []byte{0x01}, - }, - &leaf{key: []byte{}, value: nil}, - &leaf{key: []byte{0x00}, value: nil}, - &leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, - &leaf{key: 
byteArray(62), value: nil}, - &leaf{key: byteArray(63), value: nil}, - &leaf{key: byteArray(64), value: []byte{0x01}}, - &leaf{key: byteArray(318), value: []byte{0x01}}, - &leaf{key: byteArray(573), value: []byte{0x01}}, - } - - buffer := bytes.NewBuffer(nil) - const parallel = false - - for _, test := range tests { - err := encodeNode(test, buffer, parallel) - require.NoError(t, err) - - res, err := decode(buffer) - require.NoError(t, err) - - switch n := test.(type) { - case *branch: - require.Equal(t, n.key, res.(*branch).key) - require.Equal(t, n.childrenBitmap(), res.(*branch).childrenBitmap()) - require.Equal(t, n.value, res.(*branch).value) - case *leaf: - require.Equal(t, n.key, res.(*leaf).key) - require.Equal(t, n.value, res.(*leaf).value) - default: - t.Fatal("unexpected node") - } - } -} diff --git a/lib/trie/print.go b/lib/trie/print.go index ba72fde4a5..e39c8069f6 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -7,6 +7,8 @@ import ( "bytes" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/disiqueira/gotree" @@ -18,54 +20,53 @@ func (t *Trie) String() string { return "empty" } - tree := gotree.New(fmt.Sprintf("Trie root=0x%x", t.root.getHash())) + tree := gotree.New(fmt.Sprintf("Trie root=0x%x", t.root.GetHash())) t.string(tree, t.root, 0) return fmt.Sprintf("\n%s", tree.Print()) } -func (t *Trie) string(tree gotree.Tree, curr node, idx int) { +func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { switch c := curr.(type) { - case *branch: - buffer := encodingBufferPool.Get().(*bytes.Buffer) + case *node.Branch: + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() - const parallel = false - _ = encodeBranch(c, buffer, parallel) - c.encoding = buffer.Bytes() + _ = c.Encode(buffer) + encoding := buffer.Bytes() var bstr string - if len(c.encoding) > 1024 { - bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", 
idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) + if len(encoding) > 1024 { + bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", + idx, c, common.MustBlake2bHash(encoding), c.GetGeneration()) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), encoding, c.GetGeneration()) } - encodingBufferPool.Put(buffer) + pools.EncodingBuffers.Put(buffer) sub := tree.Add(bstr) - for i, child := range c.children { + for i, child := range c.Children { if child != nil { t.string(sub, child, i) } } - case *leaf: - buffer := encodingBufferPool.Get().(*bytes.Buffer) + case *node.Leaf: + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() - _ = encodeLeaf(c, buffer) + _ = c.Encode(buffer) - c.encodingMu.Lock() - defer c.encodingMu.Unlock() - c.encoding = buffer.Bytes() + encoding := buffer.Bytes() var bstr string - if len(c.encoding) > 1024 { - bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) + if len(encoding) > 1024 { + bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", + idx, c.String(), common.MustBlake2bHash(encoding), c.GetGeneration()) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), encoding, c.GetGeneration()) } - encodingBufferPool.Put(buffer) + pools.EncodingBuffers.Put(buffer) tree.Add(bstr) default: diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 2b77f846f5..2d8444d2db 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -10,6 +10,8 @@ import ( "fmt" "github.com/ChainSafe/chaindb" + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/record" "github.com/ChainSafe/gossamer/lib/common" ) @@ -40,19 +42,18 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e 
} for _, k := range keys { - nk := keyToNibbles(k) + nk := codec.KeyLEToNibbles(k) - recorder := new(recorder) + recorder := record.NewRecorder() err := findAndRecord(proofTrie, nk, recorder) if err != nil { return nil, err } - for !recorder.isEmpty() { - recNode := recorder.next() - nodeHashHex := common.BytesToHex(recNode.hash) + for _, recNode := range recorder.GetNodes() { + nodeHashHex := common.BytesToHex(recNode.Hash) if _, ok := trackedProofs[nodeHashHex]; !ok { - trackedProofs[nodeHashHex] = recNode.rawData + trackedProofs[nodeHashHex] = recNode.RawData } } } diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 2e472e2fdc..7c190d1c3c 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -69,7 +69,7 @@ func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][ err = trie.Store(memdb) require.NoError(t, err) - root := trie.root.getHash() + root := trie.root.GetHash() proof, err := GenerateProof(root, keys, memdb) require.NoError(t, err) diff --git a/lib/trie/readwriter_mock_test.go b/lib/trie/readwriter_mock_test.go deleted file mode 100644 index 6d1affa288..0000000000 --- a/lib/trie/readwriter_mock_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: io (interfaces: ReadWriter) - -// Package trie is a generated GoMock package. -package trie - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockReadWriter is a mock of ReadWriter interface. -type MockReadWriter struct { - ctrl *gomock.Controller - recorder *MockReadWriterMockRecorder -} - -// MockReadWriterMockRecorder is the mock recorder for MockReadWriter. -type MockReadWriterMockRecorder struct { - mock *MockReadWriter -} - -// NewMockReadWriter creates a new mock instance. 
-func NewMockReadWriter(ctrl *gomock.Controller) *MockReadWriter { - mock := &MockReadWriter{ctrl: ctrl} - mock.recorder = &MockReadWriterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockReadWriter) EXPECT() *MockReadWriterMockRecorder { - return m.recorder -} - -// Read mocks base method. -func (m *MockReadWriter) Read(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read. -func (mr *MockReadWriterMockRecorder) Read(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReadWriter)(nil).Read), arg0) -} - -// Write mocks base method. -func (m *MockReadWriter) Write(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. 
-func (mr *MockReadWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockReadWriter)(nil).Write), arg0) -} diff --git a/lib/trie/recorder.go b/lib/trie/recorder.go deleted file mode 100644 index 6db2a841d0..0000000000 --- a/lib/trie/recorder.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -// nodeRecord represets a record of a visited node -type nodeRecord struct { - rawData []byte - hash []byte -} - -// recorder keeps the list of nodes find by Lookup.Find -type recorder []nodeRecord - -// record insert a node inside the recorded list -func (r *recorder) record(h, rd []byte) { - *r = append(*r, nodeRecord{rawData: rd, hash: h}) -} - -// next returns the current item the cursor is on and increment the cursor by 1 -func (r *recorder) next() *nodeRecord { - if !r.isEmpty() { - n := (*r)[0] - *r = (*r)[1:] - return &n - } - - return nil -} - -// isEmpty returns bool if there is data inside the slice -func (r *recorder) isEmpty() bool { - return len(*r) <= 0 -} diff --git a/lib/trie/trie.go b/lib/trie/trie.go index fe422a2e74..c8a33d7167 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -7,6 +7,9 @@ import ( "bytes" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" ) @@ -18,10 +21,9 @@ var EmptyHash, _ = NewEmptyTrie().Hash() // Use NewTrie to create a trie that sits on top of a database. type Trie struct { generation uint64 - root node + root Node childTries map[common.Hash]*Trie // Used to store the child tries. 
deletedKeys []common.Hash - parallel bool } // NewEmptyTrie creates a trie with a nil root @@ -30,13 +32,12 @@ func NewEmptyTrie() *Trie { } // NewTrie creates a trie with an existing root node -func NewTrie(root node) *Trie { +func NewTrie(root Node) *Trie { return &Trie{ root: root, childTries: make(map[common.Hash]*Trie), generation: 0, // Initially zero but increases after every snapshot. deletedKeys: make([]common.Hash, 0), - parallel: true, } } @@ -48,7 +49,6 @@ func (t *Trie) Snapshot() *Trie { generation: c.generation + 1, root: c.root, deletedKeys: make([]common.Hash, 0), - parallel: c.parallel, } } @@ -57,25 +57,24 @@ func (t *Trie) Snapshot() *Trie { root: t.root, childTries: children, deletedKeys: make([]common.Hash, 0), - parallel: t.parallel, } return newTrie } -func (t *Trie) maybeUpdateGeneration(n node) node { +func (t *Trie) maybeUpdateGeneration(n Node) Node { if n == nil { return nil } // Make a copy if the generation is updated. - if n.getGeneration() < t.generation { + if n.GetGeneration() < t.generation { // Insert a new node in the current generation. - newNode := n.copy() - newNode.setGeneration(t.generation) + newNode := n.Copy() + newNode.SetGeneration(t.generation) // Hash of old nodes should already be computed since it belongs to older generation. 
- oldNodeHash := n.getHash() + oldNodeHash := n.GetHash() if len(oldNodeHash) > 0 { hash := common.BytesToHash(oldNodeHash) t.deletedKeys = append(t.deletedKeys, hash) @@ -102,13 +101,20 @@ func (t *Trie) DeepCopy() (*Trie, error) { } // RootNode returns the root of the trie -func (t *Trie) RootNode() node { +func (t *Trie) RootNode() Node { return t.root } // encodeRoot returns the encoded root of the trie func (t *Trie) encodeRoot(buffer *bytes.Buffer) (err error) { - return encodeNode(t.RootNode(), buffer, t.parallel) + if t.root == nil { + _, err = buffer.Write([]byte{0}) + if err != nil { + return fmt.Errorf("cannot write nil root node to buffer: %w", err) + } + return nil + } + return t.root.Encode(buffer) } // MustHash returns the hashed root of the trie. It panics if it fails to hash the root node. @@ -123,9 +129,9 @@ func (t *Trie) MustHash() common.Hash { // Hash returns the hashed root of the trie func (t *Trie) Hash() (common.Hash, error) { - buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() - defer encodingBufferPool.Put(buffer) + defer pools.EncodingBuffers.Put(buffer) err := t.encodeRoot(buffer) if err != nil { @@ -140,17 +146,17 @@ func (t *Trie) Entries() map[string][]byte { return t.entries(t.root, nil, make(map[string][]byte)) } -func (t *Trie) entries(current node, prefix []byte, kv map[string][]byte) map[string][]byte { +func (t *Trie) entries(current Node, prefix []byte, kv map[string][]byte) map[string][]byte { switch c := current.(type) { - case *branch: - if c.value != nil { - kv[string(nibblesToKeyLE(append(prefix, c.key...)))] = c.value + case *node.Branch: + if c.Value != nil { + kv[string(codec.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value } - for i, child := range c.children { - t.entries(child, append(prefix, append(c.key, byte(i))...), kv) + for i, child := range c.Children { + t.entries(child, append(prefix, append(c.Key, byte(i))...), kv) } - case *leaf: 
- kv[string(nibblesToKeyLE(append(prefix, c.key...)))] = c.value + case *node.Leaf: + kv[string(codec.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value return kv } @@ -159,20 +165,20 @@ func (t *Trie) entries(current node, prefix []byte, kv map[string][]byte) map[st // NextKey returns the next key in the trie in lexicographic order. It returns nil if there is no next key func (t *Trie) NextKey(key []byte) []byte { - k := keyToNibbles(key) + k := codec.KeyLEToNibbles(key) next := t.nextKey(t.root, nil, k) if next == nil { return nil } - return nibblesToKeyLE(next) + return codec.NibblesToKeyLE(next) } -func (t *Trie) nextKey(curr node, prefix, key []byte) []byte { +func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { switch c := curr.(type) { - case *branch: - fullKey := append(prefix, c.key...) + case *node.Branch: + fullKey := append(prefix, c.Key...) var cmp int if len(key) < len(fullKey) { if bytes.Compare(key, fullKey[:len(key)]) == 1 { // arg key is greater than full, return nil @@ -190,11 +196,11 @@ func (t *Trie) nextKey(curr node, prefix, key []byte) []byte { // return key of first child, or key of this branch, // if it's a branch with value. if (cmp == 0 && len(key) == len(fullKey)) || cmp == 1 { - if c.value != nil && bytes.Compare(fullKey, key) > 0 { + if c.Value != nil && bytes.Compare(fullKey, key) > 0 { return fullKey } - for i, child := range c.children { + for i, child := range c.Children { if child == nil { continue } @@ -209,7 +215,7 @@ func (t *Trie) nextKey(curr node, prefix, key []byte) []byte { // node key isn't greater than the arg key, continue to iterate if cmp < 1 && len(key) > len(fullKey) { idx := key[len(fullKey)] - for i, child := range c.children[idx:] { + for i, child := range c.Children[idx:] { if child == nil { continue } @@ -220,8 +226,8 @@ func (t *Trie) nextKey(curr node, prefix, key []byte) []byte { } } } - case *leaf: - fullKey := append(prefix, c.key...) + case *node.Leaf: + fullKey := append(prefix, c.Key...) 
var cmp int if len(key) < len(fullKey) { if bytes.Compare(key, fullKey[:len(key)]) == 1 { // arg key is greater than full, return nil @@ -236,7 +242,7 @@ func (t *Trie) nextKey(curr node, prefix, key []byte) []byte { } if cmp == 1 { - return append(prefix, c.key...) + return append(prefix, c.Key...) } case nil: return nil @@ -250,69 +256,71 @@ func (t *Trie) Put(key, value []byte) { } func (t *Trie) tryPut(key, value []byte) { - k := keyToNibbles(key) + k := codec.KeyLEToNibbles(key) - t.root = t.insert(t.root, k, &leaf{key: nil, value: value, dirty: true, generation: t.generation}) + t.root = t.insert(t.root, k, node.NewLeaf(nil, value, true, t.generation)) } // insert attempts to insert a key with value into the trie -func (t *Trie) insert(parent node, key []byte, value node) node { +func (t *Trie) insert(parent Node, key []byte, value Node) Node { switch p := t.maybeUpdateGeneration(parent).(type) { - case *branch: + case *node.Branch: n := t.updateBranch(p, key, value) - if p != nil && n != nil && n.isDirty() { - p.setDirty(true) + if p != nil && n != nil && n.IsDirty() { + p.SetDirty(true) } return n case nil: - value.setKey(key) + value.SetKey(key) return value - case *leaf: + case *node.Leaf: // if a value already exists in the trie at this key, overwrite it with the new value // if the values are the same, don't mark node dirty - if p.value != nil && bytes.Equal(p.key, key) { - if !bytes.Equal(value.(*leaf).value, p.value) { - p.value = value.(*leaf).value - p.dirty = true + if p.Value != nil && bytes.Equal(p.Key, key) { + if !bytes.Equal(value.(*node.Leaf).Value, p.Value) { + p.Value = value.(*node.Leaf).Value + p.SetDirty(true) } return p } - length := lenCommonPrefix(key, p.key) + length := lenCommonPrefix(key, p.Key) // need to convert this leaf into a branch - br := &branch{key: key[:length], dirty: true, generation: t.generation} - parentKey := p.key + var newBranchValue []byte + const newBranchDirty = true + br := node.NewBranch(key[:length], 
newBranchValue, newBranchDirty, t.generation) + parentKey := p.Key // value goes at this branch if len(key) == length { - br.value = value.(*leaf).value - br.setDirty(true) + br.Value = value.(*node.Leaf).Value + br.SetDirty(true) // if we are not replacing previous leaf, then add it as a child to the new branch if len(parentKey) > len(key) { - p.key = p.key[length+1:] - br.children[parentKey[length]] = p - p.setDirty(true) + p.Key = p.Key[length+1:] + br.Children[parentKey[length]] = p + p.SetDirty(true) } return br } - value.setKey(key[length+1:]) + value.SetKey(key[length+1:]) - if length == len(p.key) { + if length == len(p.Key) { // if leaf's key is covered by this branch, then make the leaf's // value the value at this branch - br.value = p.value - br.children[key[length]] = value + br.Value = p.Value + br.Children[key[length]] = value } else { // otherwise, make the leaf a child of the branch and update its partial key - p.key = p.key[length+1:] - p.setDirty(true) - br.children[parentKey[length]] = p - br.children[key[length]] = value + p.Key = p.Key[length+1:] + p.SetDirty(true) + br.Children[parentKey[length]] = p + br.Children[key[length]] = value } return br @@ -324,35 +332,35 @@ func (t *Trie) insert(parent node, key []byte, value node) node { // updateBranch attempts to add the value node to a branch // inserts the value node as the branch's child at the index that's // the first nibble of the key -func (t *Trie) updateBranch(p *branch, key []byte, value node) (n node) { - length := lenCommonPrefix(key, p.key) +func (t *Trie) updateBranch(p *node.Branch, key []byte, value Node) (n Node) { + length := lenCommonPrefix(key, p.Key) // whole parent key matches - if length == len(p.key) { + if length == len(p.Key) { // if node has same key as this branch, then update the value at this branch - if bytes.Equal(key, p.key) { - p.setDirty(true) + if bytes.Equal(key, p.Key) { + p.SetDirty(true) switch v := value.(type) { - case *branch: - p.value = v.value - case 
*leaf: - p.value = v.value + case *node.Branch: + p.Value = v.Value + case *node.Leaf: + p.Value = v.Value } return p } - switch c := p.children[key[length]].(type) { - case *branch, *leaf: + switch c := p.Children[key[length]].(type) { + case *node.Branch, *node.Leaf: n = t.insert(c, key[length+1:], value) - p.children[key[length]] = n - n.setDirty(true) - p.setDirty(true) + p.Children[key[length]] = n + n.SetDirty(true) + p.SetDirty(true) return p case nil: // otherwise, add node as child of this branch - value.(*leaf).key = key[length+1:] - p.children[key[length]] = value - p.setDirty(true) + value.(*node.Leaf).Key = key[length+1:] + p.Children[key[length]] = value + p.SetDirty(true) return p } @@ -361,18 +369,20 @@ func (t *Trie) updateBranch(p *branch, key []byte, value node) (n node) { // we need to branch out at the point where the keys diverge // update partial keys, new branch has key up to matching length - br := &branch{key: key[:length], dirty: true, generation: t.generation} + var newBranchValue []byte + const newBranchDirty = true + br := node.NewBranch(key[:length], newBranchValue, newBranchDirty, t.generation) - parentIndex := p.key[length] - br.children[parentIndex] = t.insert(nil, p.key[length+1:], p) + parentIndex := p.Key[length] + br.Children[parentIndex] = t.insert(nil, p.Key[length+1:], p) if len(key) <= length { - br.value = value.(*leaf).value + br.Value = value.(*node.Leaf).Value } else { - br.children[key[length]] = t.insert(nil, key[length+1:], value) + br.Children[key[length]] = t.insert(nil, key[length+1:], value) } - br.setDirty(true) + br.SetDirty(true) return br } @@ -397,7 +407,7 @@ func (t *Trie) LoadFromMap(data map[string]string) error { func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { var p []byte if len(prefix) != 0 { - p = keyToNibbles(prefix) + p = codec.KeyLEToNibbles(prefix) if p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -406,28 +416,28 @@ func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { return 
t.getKeysWithPrefix(t.root, []byte{}, p, [][]byte{}) } -func (t *Trie) getKeysWithPrefix(parent node, prefix, key []byte, keys [][]byte) [][]byte { +func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *branch: - length := lenCommonPrefix(p.key, key) + case *node.Branch: + length := lenCommonPrefix(p.Key, key) - if bytes.Equal(p.key[:length], key) || len(key) == 0 { + if bytes.Equal(p.Key[:length], key) || len(key) == 0 { // node has prefix, add to list and add all descendant nodes to list keys = t.addAllKeys(p, prefix, keys) return keys } - if len(key) <= len(p.key) || length < len(p.key) { + if len(key) <= len(p.Key) || length < len(p.Key) { // no prefixed keys to be found here, return return keys } - key = key[len(p.key):] - keys = t.getKeysWithPrefix(p.children[key[0]], append(append(prefix, p.key...), key[0]), key[1:], keys) - case *leaf: - length := lenCommonPrefix(p.key, key) - if bytes.Equal(p.key[:length], key) || len(key) == 0 { - keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) + key = key[len(p.Key):] + keys = t.getKeysWithPrefix(p.Children[key[0]], append(append(prefix, p.Key...), key[0]), key[1:], keys) + case *node.Leaf: + length := lenCommonPrefix(p.Key, key) + if bytes.Equal(p.Key[:length], key) || len(key) == 0 { + keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) } case nil: return keys @@ -437,18 +447,18 @@ func (t *Trie) getKeysWithPrefix(parent node, prefix, key []byte, keys [][]byte) // addAllKeys appends all keys that are descendants of the parent node to a slice of keys // it uses the prefix to determine the entire key -func (t *Trie) addAllKeys(parent node, prefix []byte, keys [][]byte) [][]byte { +func (t *Trie) addAllKeys(parent Node, prefix []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *branch: - if p.value != nil { - keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) + case *node.Branch: + if 
p.Value != nil { + keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) } - for i, child := range p.children { - keys = t.addAllKeys(child, append(append(prefix, p.key...), byte(i)), keys) + for i, child := range p.Children { + keys = t.addAllKeys(child, append(append(prefix, p.Key...), byte(i)), keys) } - case *leaf: - keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) + case *node.Leaf: + keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) case nil: return keys } @@ -463,36 +473,36 @@ func (t *Trie) Get(key []byte) []byte { return nil } - return l.value + return l.Value } -func (t *Trie) tryGet(key []byte) *leaf { - k := keyToNibbles(key) +func (t *Trie) tryGet(key []byte) *node.Leaf { + k := codec.KeyLEToNibbles(key) return t.retrieve(t.root, k) } -func (t *Trie) retrieve(parent node, key []byte) *leaf { +func (t *Trie) retrieve(parent Node, key []byte) *node.Leaf { var ( - value *leaf + value *node.Leaf ) switch p := parent.(type) { - case *branch: - length := lenCommonPrefix(p.key, key) + case *node.Branch: + length := lenCommonPrefix(p.Key, key) // found the value at this node - if bytes.Equal(p.key, key) || len(key) == 0 { - return &leaf{key: p.key, value: p.value, dirty: false} + if bytes.Equal(p.Key, key) || len(key) == 0 { + return node.NewLeaf(p.Key, p.Value, false, 0) } // did not find value - if bytes.Equal(p.key[:length], key) && len(key) < len(p.key) { + if bytes.Equal(p.Key[:length], key) && len(key) < len(p.Key) { return nil } - value = t.retrieve(p.children[key[length]], key[length+1:]) - case *leaf: - if bytes.Equal(p.key, key) { + value = t.retrieve(p.Children[key[length]], key[length+1:]) + case *node.Leaf: + if bytes.Equal(p.Key, key) { value = p } case nil: @@ -507,7 +517,7 @@ func (t *Trie) ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { return 0, false } - p := keyToNibbles(prefix) + p := codec.KeyLEToNibbles(prefix) if len(p) > 0 && p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -520,12 
+530,12 @@ func (t *Trie) ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { // clearPrefixLimit deletes the keys having the prefix till limit reached and returns updated trie root node, // true if any node in the trie got updated, and next bool returns true if there is no keys left with prefix. -func (t *Trie) clearPrefixLimit(cn node, prefix []byte, limit *uint32) (node, bool, bool) { +func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bool, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *branch: - length := lenCommonPrefix(c.key, prefix) + case *node.Branch: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { n, _ := t.deleteNodes(c, []byte{}, limit) if n == nil { @@ -534,36 +544,36 @@ func (t *Trie) clearPrefixLimit(cn node, prefix []byte, limit *uint32) (node, bo return n, true, false } - if len(prefix) == len(c.key)+1 && length == len(prefix)-1 { - i := prefix[len(c.key)] - c.children[i], _ = t.deleteNodes(c.children[i], []byte{}, limit) + if len(prefix) == len(c.Key)+1 && length == len(prefix)-1 { + i := prefix[len(c.Key)] + c.Children[i], _ = t.deleteNodes(c.Children[i], []byte{}, limit) - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) - if c.children[i] == nil { + if c.Children[i] == nil { return curr, true, true } return c, true, false } - if len(prefix) <= len(c.key) || length < len(c.key) { + if len(prefix) <= len(c.Key) || length < len(c.Key) { // this node doesn't have the prefix, return return c, false, true } - i := prefix[len(c.key)] + i := prefix[len(c.Key)] var wasUpdated, allDeleted bool - c.children[i], wasUpdated, allDeleted = t.clearPrefixLimit(c.children[i], prefix[len(c.key)+1:], limit) + c.Children[i], wasUpdated, allDeleted = t.clearPrefixLimit(c.Children[i], prefix[len(c.Key)+1:], limit) if wasUpdated { - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) } - return curr, curr.isDirty(), allDeleted - case *leaf: 
- length := lenCommonPrefix(c.key, prefix) + return curr, curr.IsDirty(), allDeleted + case *node.Leaf: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { *limit-- return nil, true, true @@ -578,35 +588,35 @@ func (t *Trie) clearPrefixLimit(cn node, prefix []byte, limit *uint32) (node, bo return nil, false, true } -func (t *Trie) deleteNodes(cn node, prefix []byte, limit *uint32) (node, bool) { +func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *leaf: + case *node.Leaf: if *limit == 0 { return c, false } *limit-- return nil, true - case *branch: - if len(c.key) != 0 { - prefix = append(prefix, c.key...) + case *node.Branch: + if len(c.Key) != 0 { + prefix = append(prefix, c.Key...) } - for i, child := range c.children { + for i, child := range c.Children { if child == nil { continue } var isDel bool - if c.children[i], isDel = t.deleteNodes(child, prefix, limit); !isDel { + if c.Children[i], isDel = t.deleteNodes(child, prefix, limit); !isDel { continue } - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) - isAllNil := c.numChildren() == 0 - if isAllNil && c.value == nil { + isAllNil := c.NumChildren() == 0 + if isAllNil && c.Value == nil { curr = nil } @@ -620,7 +630,7 @@ func (t *Trie) deleteNodes(cn node, prefix []byte, limit *uint32) (node, bool) { } // Delete the current node as well - if c.value != nil { + if c.Value != nil { *limit-- } return nil, true @@ -636,7 +646,7 @@ func (t *Trie) ClearPrefix(prefix []byte) { return } - p := keyToNibbles(prefix) + p := codec.KeyLEToNibbles(prefix) if len(p) > 0 && p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -644,11 +654,11 @@ func (t *Trie) ClearPrefix(prefix []byte) { t.root, _ = t.clearPrefix(t.root, p) } -func (t *Trie) clearPrefix(cn node, prefix []byte) (node, bool) { +func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := 
curr.(type) { - case *branch: - length := lenCommonPrefix(c.key, prefix) + case *node.Branch: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { // found prefix at this branch, delete it @@ -657,32 +667,32 @@ func (t *Trie) clearPrefix(cn node, prefix []byte) (node, bool) { // Store the current node and return it, if the trie is not updated. - if len(prefix) == len(c.key)+1 && length == len(prefix)-1 { + if len(prefix) == len(c.Key)+1 && length == len(prefix)-1 { // found prefix at child index, delete child - i := prefix[len(c.key)] - c.children[i] = nil - c.setDirty(true) + i := prefix[len(c.Key)] + c.Children[i] = nil + c.SetDirty(true) curr = handleDeletion(c, prefix) return curr, true } - if len(prefix) <= len(c.key) || length < len(c.key) { + if len(prefix) <= len(c.Key) || length < len(c.Key) { // this node doesn't have the prefix, return return c, false } var wasUpdated bool - i := prefix[len(c.key)] + i := prefix[len(c.Key)] - c.children[i], wasUpdated = t.clearPrefix(c.children[i], prefix[len(c.key)+1:]) + c.Children[i], wasUpdated = t.clearPrefix(c.Children[i], prefix[len(c.Key)+1:]) if wasUpdated { - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) } - return curr, curr.isDirty() - case *leaf: - length := lenCommonPrefix(c.key, prefix) + return curr, curr.IsDirty() + case *node.Leaf: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { return nil, true } @@ -696,35 +706,35 @@ func (t *Trie) clearPrefix(cn node, prefix []byte) (node, bool) { // Delete removes any existing value for key from the trie. func (t *Trie) Delete(key []byte) { - k := keyToNibbles(key) + k := codec.KeyLEToNibbles(key) t.root, _ = t.delete(t.root, k) } -func (t *Trie) delete(parent node, key []byte) (node, bool) { +func (t *Trie) delete(parent Node, key []byte) (Node, bool) { // Store the current node and return it, if the trie is not updated. 
switch p := t.maybeUpdateGeneration(parent).(type) { - case *branch: + case *node.Branch: - length := lenCommonPrefix(p.key, key) - if bytes.Equal(p.key, key) || len(key) == 0 { + length := lenCommonPrefix(p.Key, key) + if bytes.Equal(p.Key, key) || len(key) == 0 { // found the value at this node - p.value = nil - p.setDirty(true) + p.Value = nil + p.SetDirty(true) return handleDeletion(p, key), true } - n, del := t.delete(p.children[key[length]], key[length+1:]) + n, del := t.delete(p.Children[key[length]], key[length+1:]) if !del { // If nothing was deleted then don't copy the path. return p, false } - p.children[key[length]] = n - p.setDirty(true) + p.Children[key[length]] = n + p.SetDirty(true) n = handleDeletion(p, key) return n, true - case *leaf: - if bytes.Equal(key, p.key) || len(key) == 0 { + case *node.Leaf: + if bytes.Equal(key, p.Key) || len(key) == 0 { // Key exists. Delete it. return nil, true } @@ -740,15 +750,15 @@ func (t *Trie) delete(parent node, key []byte) (node, bool) { // handleDeletion is called when a value is deleted from a branch // if the updated branch only has 1 child, it should be combined with that child // if the updated branch only has a value, it should be turned into a leaf -func handleDeletion(p *branch, key []byte) node { - var n node = p - length := lenCommonPrefix(p.key, key) - bitmap := p.childrenBitmap() +func handleDeletion(p *node.Branch, key []byte) Node { + var n Node = p + length := lenCommonPrefix(p.Key, key) + bitmap := p.ChildrenBitmap() // if branch has no children, just a value, turn it into a leaf - if bitmap == 0 && p.value != nil { - n = &leaf{key: key[:length], value: p.value, dirty: true} - } else if p.numChildren() == 1 && p.value == nil { + if bitmap == 0 && p.Value != nil { + n = node.NewLeaf(key[:length], p.Value, true, 0) + } else if p.NumChildren() == 1 && p.Value == nil { // there is only 1 child and no value, combine the child branch with this branch // find index of child var i int @@ -759,27 
+769,27 @@ func handleDeletion(p *branch, key []byte) node { } } - child := p.children[i] + child := p.Children[i] switch c := child.(type) { - case *leaf: - n = &leaf{key: append(append(p.key, []byte{byte(i)}...), c.key...), value: c.value} - case *branch: - br := new(branch) - br.key = append(p.key, append([]byte{byte(i)}, c.key...)...) + case *node.Leaf: + n = &node.Leaf{Key: append(append(p.Key, []byte{byte(i)}...), c.Key...), Value: c.Value} + case *node.Branch: + br := new(node.Branch) + br.Key = append(p.Key, append([]byte{byte(i)}, c.Key...)...) // adopt the grandchildren - for i, grandchild := range c.children { + for i, grandchild := range c.Children { if grandchild != nil { - br.children[i] = grandchild + br.Children[i] = grandchild } } - br.value = c.value + br.Value = c.Value n = br default: // do nothing } - n.setDirty(true) + n.SetDirty(true) } return n diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index f6ae3ff779..cc19116a50 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -20,6 +20,8 @@ import ( "github.com/ChainSafe/chaindb" "github.com/stretchr/testify/require" + "github.com/ChainSafe/gossamer/internal/trie/codec" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" ) @@ -68,7 +70,7 @@ func TestNewEmptyTrie(t *testing.T) { } func TestNewTrie(t *testing.T) { - trie := NewTrie(&leaf{key: []byte{0}, value: []byte{17}}) + trie := NewTrie(&node.Leaf{Key: []byte{0}, Value: []byte{17}}) if trie == nil { t.Error("did not initialise trie") } @@ -160,10 +162,10 @@ func runTests(t *testing.T, trie *Trie, tests []Test) { leaf := trie.tryGet(test.key) if leaf == nil { t.Errorf("Fail to get key %x: nil leaf", test.key) - } else if !bytes.Equal(leaf.value, test.value) { - t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, leaf.value) - } else if !bytes.Equal(leaf.key, test.pk) { - t.Errorf("Fail to get correct partial key %x with key %x: got %x", test.pk, test.key, 
leaf.key) + } else if !bytes.Equal(leaf.Value, test.value) { + t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, leaf.Value) + } else if !bytes.Equal(leaf.Key, test.pk) { + t.Errorf("Fail to get correct partial key %x with key %x: got %x", test.pk, test.key, leaf.Key) } } }) @@ -515,7 +517,7 @@ func TestTrieDiff(t *testing.T) { } dbTrie := NewEmptyTrie() - err = dbTrie.Load(storageDB, common.BytesToHash(newTrie.root.getHash())) + err = dbTrie.Load(storageDB, common.BytesToHash(newTrie.root.GetHash())) require.NoError(t, err) } @@ -873,7 +875,7 @@ func TestClearPrefix(t *testing.T) { require.Equal(t, dcTrieHash, ssTrieHash) ssTrie.ClearPrefix(prefix) - prefixNibbles := keyToNibbles(prefix) + prefixNibbles := codec.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -881,7 +883,7 @@ func TestClearPrefix(t *testing.T) { for _, test := range tests { res := ssTrie.Get(test.key) - keyNibbles := keyToNibbles(test.key) + keyNibbles := codec.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { require.Nil(t, res) @@ -942,7 +944,14 @@ func TestClearPrefix_Small(t *testing.T) { } ssTrie.ClearPrefix([]byte("noo")) - require.Equal(t, ssTrie.root, &leaf{key: keyToNibbles([]byte("other")), value: []byte("other"), dirty: true}) + + expectedRoot := &node.Leaf{ + Key: codec.KeyLEToNibbles([]byte("other")), + Value: []byte("other"), + } + expectedRoot.SetDirty(true) + + require.Equal(t, expectedRoot, ssTrie.root) // Get the updated root hash of all tries. 
tHash, err = trie.Hash() @@ -1125,35 +1134,12 @@ func Benchmark_Trie_Hash(b *testing.B) { trie.Put(test.key, test.value) } - trieTwo, err := trie.DeepCopy() - require.NoError(b, err) - - b.Run("Sequential hash", func(b *testing.B) { - trie.parallel = false - - b.StartTimer() - _, err := trie.Hash() - b.StopTimer() - - require.NoError(b, err) + b.StartTimer() + _, err := trie.Hash() + b.StopTimer() - printMemUsage() - }) - - b.Run("Parallel hash", func(b *testing.B) { - trieTwo.parallel = true - - b.StartTimer() - _, err := trieTwo.Hash() - b.StopTimer() - - require.NoError(b, err) - - printMemUsage() - }) -} + require.NoError(b, err) -func printMemUsage() { var m runtime.MemStats runtime.ReadMemStats(&m) // For info on each, see: https://golang.org/pkg/runtime/#MemStats @@ -1310,7 +1296,7 @@ func TestTrie_ClearPrefixLimit(t *testing.T) { } testFn := func(testCase []Test, prefix []byte) { - prefixNibbles := keyToNibbles(prefix) + prefixNibbles := codec.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -1329,7 +1315,7 @@ func TestTrie_ClearPrefixLimit(t *testing.T) { for _, test := range testCase { val := trieClearPrefix.Get(test.key) - keyNibbles := keyToNibbles(test.key) + keyNibbles := codec.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { @@ -1418,7 +1404,7 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { for _, testCase := range cases { for _, prefix := range prefixes { - prefixNibbles := keyToNibbles(prefix) + prefixNibbles := codec.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -1458,7 +1444,7 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { for _, test := range testCase { val := ssTrie.Get(test.key) - keyNibbles := keyToNibbles(test.key) + keyNibbles := 
codec.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) {