diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 2dac9d3eeb..a74994b0da 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -110,14 +110,11 @@ func decodeBranch(reader io.Reader, variant byte, partialKeyLength uint16) ( if len(hash) < hashLength { // Handle inlined nodes reader = bytes.NewReader(hash) - variant, partialKeyLength, err := decodeHeader(reader) - if err == nil && variant == leafVariant.bits { - childNode, err = decodeLeaf(reader, partialKeyLength) - if err != nil { - return nil, fmt.Errorf("%w: at index %d: %s", - ErrDecodeValue, i, err) - } + childNode, err = Decode(reader) + if err != nil { + return nil, fmt.Errorf("decoding inlined child at index %d: %w", i, err) } + node.Descendants += childNode.Descendants } node.Descendants++ diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 2e8e0967e2..9cf9979dda 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -14,6 +14,10 @@ import ( ) func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { + return scaleEncodeByteSlice(t, b) +} + +func scaleEncodeByteSlice(t *testing.T, b []byte) (encoded []byte) { encoded, err := scale.Marshal(b) require.NoError(t, err) return encoded @@ -98,66 +102,6 @@ func Test_Decode(t *testing.T) { Dirty: true, }, }, - "branch with two inlined children": { - reader: bytes.NewReader( - []byte{ - branchVariant.bits | 30, // key length 30 - // Key data start - 195, 101, 195, 207, 89, 214, - 113, 235, 114, 218, 14, 122, - 65, 19, 196, 16, 2, 80, 95, - 14, 123, 144, 18, 9, 107, - 65, 196, 235, 58, 175, - // Key data end - 148, 127, 110, 164, 41, 8, 0, 0, 104, 95, 15, 31, 5, - 21, 244, 98, 205, 207, 132, 224, 241, 214, 4, 93, 252, - 187, 32, 134, 92, 74, 43, 127, 1, 0, 0, - }, - ), - n: &Node{ - Key: []byte{ - 12, 3, 6, 5, 12, 3, - 12, 15, 5, 9, 13, 6, - 7, 1, 14, 11, 7, 2, - 13, 10, 0, 14, 7, 10, - 4, 1, 1, 3, 12, 4, - }, - Descendants: 2, - Children: []*Node{ - nil, nil, nil, nil, - { - Key: []byte{ - 14, 7, 11, 9, 0, 1, - 2, 0, 9, 6, 11, 4, - 1, 12, 4, 14, 11, - 3, 10, 10, 15, 9, - 4, 7, 15, 6, 14, - 10, 4, 2, 9, - }, - Value: []byte{0, 0}, - Dirty: true, - }, - nil, nil, nil, nil, - { - Key: []byte{ - 15, 1, 15, 0, 5, 1, - 5, 15, 4, 6, 2, 12, - 13, 12, 15, 8, 4, - 14, 0, 15, 1, 13, - 6, 0, 4, 5, 13, - 15, 12, 11, 11, - }, - Value: []byte{ - 134, 92, 74, 43, - 127, 1, 0, 0, - }, - Dirty: true, - }, - nil, nil, nil, nil, nil, nil, - }, - Dirty: true, - }, - }, } for name, testCase := range testCases { @@ -179,6 +123,13 @@ func Test_Decode(t *testing.T) { func Test_decodeBranch(t *testing.T) { t.Parallel() + const childHashLength = 32 + childHash := make([]byte, childHashLength) + for i := range childHash { + childHash[i] = byte(i) + } + scaleEncodedChildHash := scaleEncodeByteSlice(t, childHash) + testCases := map[string]struct { reader io.Reader variant byte @@ -220,9 +171,9 @@ func Test_decodeBranch(t *testing.T) { "success for branch variant": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ - {9}, // key data - {0, 4}, // children bitmap - scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash + {9}, // key data + {0, 4}, // children bitmap + scaleEncodedChildHash, }), ), variant: branchVariant.bits, @@ -233,7 +184,7 @@ func Test_decodeBranch(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, { - HashDigest: []byte{1, 2, 3, 4, 5}, + HashDigest: childHash, Dirty: true, }, }), @@ -255,14 +206,12 @@ func Test_decodeBranch(t *testing.T) { errMessage: "cannot decode value: EOF", }, "success for branch with value": { - reader: bytes.NewBuffer( - concatByteSlices([][]byte{ - {9}, // key data - {0, 4}, // children bitmap - scaleEncodeBytes(t, 7, 8, 9), // branch value - scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash - }), - ), + reader: bytes.NewBuffer(concatByteSlices([][]byte{ + {9}, // key data + {0, 4}, // children bitmap + scaleEncodeBytes(t, 7, 8, 9), // branch value + scaleEncodedChildHash, + })), variant: branchWithValueVariant.bits, partialKeyLength: 1, branch: &Node{ @@ -272,7 +221,7 @@ func Test_decodeBranch(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, { - HashDigest: []byte{1, 2, 3, 4, 5}, + HashDigest: childHash, Dirty: true, }, }), @@ -280,6 +229,63 @@ func Test_decodeBranch(t *testing.T) { Descendants: 1, }, }, + "branch with inlined node decoding error": { + reader: bytes.NewBuffer(concatByteSlices([][]byte{ + {1}, // key data + {0b0000_0001, 0b0000_0000}, // children bitmap + scaleEncodeBytes(t, 1), // branch value + {0}, // garbage inlined node + })), + variant: branchWithValueVariant.bits, + partialKeyLength: 1, + errWrapped: io.EOF, + errMessage: "decoding inlined child at index 0: " + + "decoding header: reading header byte: EOF", + }, + "branch with inlined branch and leaf": { + reader: bytes.NewBuffer(concatByteSlices([][]byte{ + {1}, // key data + {0b0000_0011, 0b0000_0000}, // children bitmap + // top level inlined leaf less than 32 bytes + scaleEncodeByteSlice(t, concatByteSlices([][]byte{ + {leafVariant.bits | 1}, // partial key length of 1 + {2}, // key data + scaleEncodeBytes(t, 2), // value data + })), + // top level inlined branch less than 32 bytes + scaleEncodeByteSlice(t, concatByteSlices([][]byte{ + {branchWithValueVariant.bits | 1}, // partial key length of 1 + {3}, // key data + {0b0000_0001, 0b0000_0000}, // children bitmap + scaleEncodeBytes(t, 3), // branch value + // bottom level leaf + scaleEncodeByteSlice(t, concatByteSlices([][]byte{ + {leafVariant.bits | 1}, // partial key length of 1 + {4}, // key data + scaleEncodeBytes(t, 4), // value data + })), + })), + })), + variant: branchVariant.bits, + partialKeyLength: 1, + branch: &Node{ + Key: []byte{1}, + Descendants: 3, + Children: padRightChildren([]*Node{ + {Key: []byte{2}, Value: []byte{2}, Dirty: true}, + { + Key: []byte{3}, + Value: []byte{3}, + Dirty: true, + Descendants: 1, + Children: padRightChildren([]*Node{ + {Key: []byte{4}, Value: []byte{4}, Dirty: true}, + }), + }, + }), + Dirty: true, + }, + }, } for name, testCase := range testCases { diff --git a/lib/trie/database.go b/lib/trie/database.go index 4c608b9d1c..2a676882b4 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -182,7 +182,7 @@ func (t *Trie) load(db chaindb.Database, n *Node) error { hash := child.HashDigest - if len(hash) == 0 && child.Type() == node.Leaf { + if len(hash) == 0 { // node has already been loaded inline // just set encoding + hash digest _, _, err := child.EncodeAndHash(false)