From 93ee8b647f01dc915e0054494b7b4fc3af6e4d98 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 9 Dec 2021 22:55:23 +0000 Subject: [PATCH] `node.Decode` function --- internal/trie/node/decode.go | 43 +++++++++- internal/trie/node/decode_test.go | 90 +++++++++++++++++++- internal/trie/node/encode_decode_test.go | 2 +- lib/trie/database.go | 10 +-- lib/trie/decode.go | 48 ----------- lib/trie/decode_test.go | 103 ----------------------- 6 files changed, 131 insertions(+), 165 deletions(-) delete mode 100644 lib/trie/decode.go delete mode 100644 lib/trie/decode_test.go diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 733ec85aee..05be18c2b4 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -4,15 +4,18 @@ package node import ( + "bytes" "errors" "fmt" "io" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/pkg/scale" ) var ( ErrReadHeaderByte = errors.New("cannot read header byte") + ErrUnknownNodeType = errors.New("unknown node type") ErrNodeTypeIsNotABranch = errors.New("node type is not a branch") ErrNodeTypeIsNotALeaf = errors.New("node type is not a leaf") ErrDecodeValue = errors.New("cannot decode value") @@ -20,12 +23,44 @@ var ( ErrDecodeChildHash = errors.New("cannot decode child hash") ) -// DecodeBranch reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. +// Decode decodes a node from a reader. +// For branch decoding, see the comments on decodeBranch. +// For leaf decoding, see the comments on decodeLeaf. +func Decode(reader io.Reader) (n Node, err error) { + buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) + defer pools.SingleByteBuffers.Put(buffer) + oneByteBuf := buffer.Bytes() + _, err = reader.Read(oneByteBuf) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) + } + header := oneByteBuf[0] + + nodeType := header >> 6 + switch nodeType { + case LeafType: + n, err = decodeLeaf(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode leaf: %w", err) + } + return n, nil + case BranchType, BranchWithValueType: + n, err = decodeBranch(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode branch: %w", err) + } + return n, nil + default: + return nil, fmt.Errorf("%w: %d", ErrUnknownNodeType, nodeType) + } +} + +// decodeBranch reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. // Note that since the encoded branch stores the hash of the children nodes, we are not // reconstructing the child nodes from the encoding. This function instead stubs where the // children are known to be with an empty leaf. The children nodes hashes are then used to // find other values using the persistent database. -func DecodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { +func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { nodeType := header >> 6 if nodeType != 2 && nodeType != 3 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) @@ -78,8 +113,8 @@ func DecodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { return branch, nil } -// DecodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. -func DecodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { +// decodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. +func decodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { nodeType := header >> 6 if nodeType != 1 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index c6840f0a74..b3b2d91ef8 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -31,7 +31,89 @@ func concatByteSlices(slices [][]byte) (concatenated []byte) { return concatenated } -func Test_DecodeBranch(t *testing.T) { +func Test_Decode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + n Node + errWrapped error + errMessage string + }{ + "no data": { + reader: bytes.NewReader(nil), + errWrapped: ErrReadHeaderByte, + errMessage: "cannot read header byte: EOF", + }, + "unknown node type": { + reader: bytes.NewReader([]byte{0}), + errWrapped: ErrUnknownNodeType, + errMessage: "unknown node type: 0", + }, + "leaf decoding error": { + reader: bytes.NewReader([]byte{ + 65, // node type 1 and key length 1 + // missing key data byte + }), + errWrapped: ErrReadKeyData, + errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", + }, + "leaf success": { + reader: bytes.NewReader( + append( + []byte{ + 65, // node type 1 and key length 1 + 9, // key data + }, + scaleEncodeBytes(t, 1, 2, 3)..., + ), + ), + n: &Leaf{ + Key: []byte{9}, + Value: []byte{1, 2, 3}, + Dirty: true, + }, + }, + "branch decoding error": { + reader: bytes.NewReader([]byte{ + 129, // node type 2 and key length 1 + // missing key data byte + }), + errWrapped: ErrReadKeyData, + errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", + }, + "branch success": { + reader: bytes.NewReader( + []byte{ + 129, // node type 2 and key length 1 + 9, // key data + 0, 0, // no children bitmap + }, + ), + n: &Branch{ + Key: []byte{9}, + Dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + n, err := Decode(testCase.reader) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.n, n) + }) + } +} + +func Test_decodeBranch(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -139,7 +221,7 @@ func Test_DecodeBranch(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - branch, err := DecodeBranch(testCase.reader, testCase.header) + branch, err := decodeBranch(testCase.reader, testCase.header) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { @@ -150,7 +232,7 @@ func Test_DecodeBranch(t *testing.T) { } } -func Test_DecodeLeaf(t *testing.T) { +func Test_decodeLeaf(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -215,7 +297,7 @@ func Test_DecodeLeaf(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - leaf, err := DecodeLeaf(testCase.reader, testCase.header) + leaf, err := decodeLeaf(testCase.reader, testCase.header) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go index cc380060d6..f8ba60df3f 100644 --- a/internal/trie/node/encode_decode_test.go +++ b/internal/trie/node/encode_decode_test.go @@ -80,7 +80,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { require.NoError(t, err) header := oneBuffer[0] - resultBranch, err := DecodeBranch(buffer, header) + resultBranch, err := decodeBranch(buffer, header) require.NoError(t, err) assert.Equal(t, testCase.branchDecoded, resultBranch) diff --git a/lib/trie/database.go b/lib/trie/database.go index 2f003e701e..362720c5ce 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -79,7 +79,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // map all the proofs hash -> decoded node // and takes the loop to indentify the root node for _, rawNode := range proof { - decNode, err := decodeNode(bytes.NewReader(rawNode)) + decNode, err := node.Decode(bytes.NewReader(rawNode)) if err != nil { return err } @@ -139,7 +139,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return fmt.Errorf("failed to find root key=%s: %w", root, err) } - t.root, err = decodeNode(bytes.NewReader(enc)) + t.root, err = node.Decode(bytes.NewReader(enc)) if err != nil { return err } @@ -163,7 +163,7 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err) } - child, err = decodeNode(bytes.NewReader(enc)) + child, err = node.Decode(bytes.NewReader(enc)) if err != nil { return err } @@ -243,7 +243,7 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return nil, fmt.Errorf("failed to find root key=%s: %w", root, err) } - rootNode, err := decodeNode(bytes.NewReader(enc)) + rootNode, err := node.Decode(bytes.NewReader(enc)) if err != nil { return nil, err } @@ -278,7 +278,7 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { return nil, fmt.Errorf("failed to find node in database: %w", err) } - child, err := decodeNode(bytes.NewReader(enc)) + child, err := node.Decode(bytes.NewReader(enc)) if err != nil { return nil, err } diff --git a/lib/trie/decode.go b/lib/trie/decode.go deleted file mode 100644 index 730e251b61..0000000000 --- a/lib/trie/decode.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "errors" - "fmt" - "io" - - "github.com/ChainSafe/gossamer/internal/trie/node" - "github.com/ChainSafe/gossamer/internal/trie/pools" -) - -var ( - ErrReadHeaderByte = errors.New("cannot read header byte") - ErrUnknownNodeType = errors.New("unknown node type") -) - -func decodeNode(reader io.Reader) (n Node, err error) { - buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) - defer pools.SingleByteBuffers.Put(buffer) - oneByteBuf := buffer.Bytes() - _, err = reader.Read(oneByteBuf) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) - } - header := oneByteBuf[0] - - nodeType := header >> 6 - switch nodeType { - case node.LeafType: - n, err = node.DecodeLeaf(reader, header) - if err != nil { - return nil, fmt.Errorf("cannot decode leaf: %w", err) - } - return n, nil - case node.BranchType, node.BranchWithValueType: - n, err = node.DecodeBranch(reader, header) - if err != nil { - return nil, fmt.Errorf("cannot decode branch: %w", err) - } - return n, nil - default: - return nil, fmt.Errorf("%w: %d", ErrUnknownNodeType, nodeType) - } -} diff --git a/lib/trie/decode_test.go b/lib/trie/decode_test.go deleted file mode 100644 index 34788bbf1a..0000000000 --- a/lib/trie/decode_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "io" - "testing" - - "github.com/ChainSafe/gossamer/internal/trie/node" - "github.com/ChainSafe/gossamer/pkg/scale" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { - encoded, err := scale.Marshal(b) - require.NoError(t, err) - return encoded -} - -func Test_decodeNode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - reader io.Reader - n Node - errWrapped error - errMessage string - }{ - "no data": { - reader: bytes.NewReader(nil), - errWrapped: ErrReadHeaderByte, - errMessage: "cannot read header byte: EOF", - }, - "unknown node type": { - reader: bytes.NewReader([]byte{0}), - errWrapped: ErrUnknownNodeType, - errMessage: "unknown node type: 0", - }, - "leaf decoding error": { - reader: bytes.NewReader([]byte{ - 65, // node type 1 and key length 1 - // missing key data byte - }), - errWrapped: node.ErrReadKeyData, - errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", - }, - "leaf success": { - reader: bytes.NewReader( - append( - []byte{ - 65, // node type 1 and key length 1 - 9, // key data - }, - scaleEncodeBytes(t, 1, 2, 3)..., - ), - ), - n: &node.Leaf{ - Key: []byte{9}, - Value: []byte{1, 2, 3}, - Dirty: true, - }, - }, - "branch decoding error": { - reader: bytes.NewReader([]byte{ - 129, // node type 2 and key length 1 - // missing key data byte - }), - errWrapped: node.ErrReadKeyData, - errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", - }, - "branch success": { - reader: bytes.NewReader( - []byte{ - 129, // node type 2 and key length 1 - 9, // key data - 0, 0, // no children bitmap - }, - ), - n: &node.Branch{ - Key: []byte{9}, - Dirty: true, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - n, err := decodeNode(testCase.reader) - - assert.ErrorIs(t, err, testCase.errWrapped) - if err != nil { - assert.EqualError(t, err, testCase.errMessage) - } - assert.Equal(t, testCase.n, n) - }) - } -}