diff --git a/internal/trie/node/README.md b/internal/trie/node/README.md new file mode 100644 index 0000000000..cae00b0ff0 --- /dev/null +++ b/internal/trie/node/README.md @@ -0,0 +1,31 @@ +# Trie node + +Package node defines the `Node` structure with methods to be used in the modified Merkle-Patricia Radix-16 trie. + +## Codec + +The following sub-sections specify the encoding of a node. +This encoding is formally described in [the Polkadot specification](https://spec.polkadot.network/#sect-state-storage). + +### Header + +Each node encoding has a header of one or more bytes. +The first byte contains the node variant and some or all of the partial key length of the node. +If the partial key length cannot fit in the first byte, additional bytes are added to the header to represent the total partial key length. + +### Partial key + +The header is then concatenated with the partial key of the node, encoded as Little Endian bytes. + +### Remaining bytes + +The remaining bytes appended depend on the node variant. + +- For leaves, the SCALE-encoded leaf value is appended. +- For branches, the following elements are concatenated in this order and appended to the previous header+partial key: + - Children bitmap (2 bytes) + - SCALE-encoded node value + - Hash(Encoding(Child[0])) + - Hash(Encoding(Child[1])) + - ... + - Hash(Encoding(Child[15])) diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index cb6930bbee..2dac9d3eeb 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -9,63 +9,68 @@ import ( "fmt" "io" - "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/pkg/scale" ) var ( - ErrReadHeaderByte = errors.New("cannot read header byte") - ErrUnknownNodeType = errors.New("unknown node type") + // ErrDecodeValue is defined since no sentinel error is defined + // in the scale package. + // TODO remove once the following issue is done: + // https://github.com/ChainSafe/gossamer/issues/2631 . 
ErrDecodeValue = errors.New("cannot decode value") ErrReadChildrenBitmap = errors.New("cannot read children bitmap") - ErrDecodeChildHash = errors.New("cannot decode child hash") + // ErrDecodeChildHash is defined since no sentinel error is defined + // in the scale package. + // TODO remove once the following issue is done: + // https://github.com/ChainSafe/gossamer/issues/2631 . + ErrDecodeChildHash = errors.New("cannot decode child hash") ) // Decode decodes a node from a reader. +// The encoding format is documented in the README.md +// of this package, and specified in the Polkadot spec at +// https://spec.polkadot.network/#sect-state-storage // For branch decoding, see the comments on decodeBranch. // For leaf decoding, see the comments on decodeLeaf. func Decode(reader io.Reader) (n *Node, err error) { - buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) - defer pools.SingleByteBuffers.Put(buffer) - oneByteBuf := buffer.Bytes() - _, err = reader.Read(oneByteBuf) + variant, partialKeyLength, err := decodeHeader(reader) if err != nil { - return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) + return nil, fmt.Errorf("decoding header: %w", err) } - header := oneByteBuf[0] - nodeTypeHeaderByte := header >> 6 - switch nodeTypeHeaderByte { - case leafHeader: - n, err = decodeLeaf(reader, header) + switch variant { + case leafVariant.bits: + n, err = decodeLeaf(reader, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode leaf: %w", err) } return n, nil - case branchHeader, branchWithValueHeader: - n, err = decodeBranch(reader, header) + case branchVariant.bits, branchWithValueVariant.bits: + n, err = decodeBranch(reader, variant, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode branch: %w", err) } return n, nil default: - return nil, fmt.Errorf("%w: %d", ErrUnknownNodeType, nodeTypeHeaderByte) + // this is a programming error, an unknown node variant + // should be caught by decodeHeader. 
+ panic(fmt.Sprintf("not implemented for node variant %08b", variant)) } } -// decodeBranch reads and decodes from a reader with the encoding specified in internal/trie/node/encode_doc.go. +// decodeBranch reads from a reader and decodes to a node branch. // Note that since the encoded branch stores the hash of the children nodes, we are not // reconstructing the child nodes from the encoding. This function instead stubs where the // children are known to be with an empty leaf. The children nodes hashes are then used to // find other values using the persistent database. -func decodeBranch(reader io.Reader, header byte) (node *Node, err error) { +func decodeBranch(reader io.Reader, variant byte, partialKeyLength uint16) ( + node *Node, err error) { node = &Node{ Dirty: true, Children: make([]*Node, ChildrenCapacity), } - keyLen := header & keyLenOffset - node.Key, err = decodeKey(reader, keyLen) + node.Key, err = decodeKey(reader, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) } @@ -78,18 +83,14 @@ func decodeBranch(reader io.Reader, header byte) (node *Node, err error) { sd := scale.NewDecoder(reader) - nodeType := header >> 6 - if nodeType == branchWithValueHeader { - var value []byte - // branch w/ value - err := sd.Decode(&value) + if variant == branchWithValueVariant.bits { + err := sd.Decode(&node.Value) if err != nil { return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) } - node.Value = value } - for i := 0; i < 16; i++ { + for i := 0; i < ChildrenCapacity; i++ { if (childrenBitmap[i/8]>>(i%8))&1 != 1 { continue } @@ -101,37 +102,38 @@ func decodeBranch(reader io.Reader, header byte) (node *Node, err error) { ErrDecodeChildHash, i, err) } - // Handle inlined leaf nodes. 
const hashLength = 32 - nodeTypeHeaderByte := hash[0] >> 6 - if nodeTypeHeaderByte == leafHeader && len(hash) < hashLength { - leaf, err := decodeLeaf(bytes.NewReader(hash[1:]), hash[0]) - if err != nil { - return nil, fmt.Errorf("%w: at index %d: %s", - ErrDecodeValue, i, err) + childNode := &Node{ + HashDigest: hash, + Dirty: true, + } + if len(hash) < hashLength { + // Handle inlined nodes + reader = bytes.NewReader(hash) + variant, partialKeyLength, err := decodeHeader(reader) + if err == nil && variant == leafVariant.bits { + childNode, err = decodeLeaf(reader, partialKeyLength) + if err != nil { + return nil, fmt.Errorf("%w: at index %d: %s", + ErrDecodeValue, i, err) + } } - node.Descendants++ - node.Children[i] = leaf - continue } node.Descendants++ - node.Children[i] = &Node{ - HashDigest: hash, - } + node.Children[i] = childNode } return node, nil } -// decodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. -func decodeLeaf(reader io.Reader, header byte) (node *Node, err error) { +// decodeLeaf reads from a reader and decodes to a leaf node. 
+func decodeLeaf(reader io.Reader, partialKeyLength uint16) (node *Node, err error) { node = &Node{ Dirty: true, } - keyLen := header & keyLenOffset - node.Key, err = decodeKey(reader, keyLen) + node.Key, err = decodeKey(reader, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) } diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 6a0a916b81..2e8e0967e2 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -42,28 +42,29 @@ func Test_Decode(t *testing.T) { }{ "no data": { reader: bytes.NewReader(nil), - errWrapped: ErrReadHeaderByte, - errMessage: "cannot read header byte: EOF", + errWrapped: io.EOF, + errMessage: "decoding header: reading header byte: EOF", }, - "unknown node type": { + "unknown node variant": { reader: bytes.NewReader([]byte{0}), - errWrapped: ErrUnknownNodeType, - errMessage: "unknown node type: 0", + errWrapped: ErrVariantUnknown, + errMessage: "decoding header: decoding header byte: node variant is unknown: for header byte 00000000", }, "leaf decoding error": { reader: bytes.NewReader([]byte{ - 65, // node type 1 (leaf) and key length 1 + leafVariant.bits | 1, // key length 1 // missing key data byte }), - errWrapped: ErrReadKeyData, - errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", + errWrapped: io.EOF, + errMessage: "cannot decode leaf: cannot decode key: " + + "reading from reader: EOF", }, "leaf success": { reader: bytes.NewReader( append( []byte{ - 65, // node type 1 (leaf) and key length 1 - 9, // key data + leafVariant.bits | 1, // key length 1 + 9, // key data }, scaleEncodeBytes(t, 1, 2, 3)..., ), @@ -76,18 +77,19 @@ func Test_Decode(t *testing.T) { }, "branch decoding error": { reader: bytes.NewReader([]byte{ - 129, // node type 2 (branch without value) and key length 1 + branchVariant.bits | 1, // key length 1 // missing key data byte }), - errWrapped: ErrReadKeyData, - errMessage: "cannot 
decode branch: cannot decode key: cannot read key data: EOF", + errWrapped: io.EOF, + errMessage: "cannot decode branch: cannot decode key: " + + "reading from reader: EOF", }, "branch success": { reader: bytes.NewReader( []byte{ - 129, // node type 2 (branch without value) and key length 1 - 9, // key data - 0, 0, // no children bitmap + branchVariant.bits | 1, // key length 1 + 9, // key data + 0, 0, // no children bitmap }, ), n: &Node{ @@ -99,7 +101,7 @@ func Test_Decode(t *testing.T) { "branch with two inlined children": { reader: bytes.NewReader( []byte{ - 158, // node type 2 (branch w/o value) and key length 30 + branchVariant.bits | 30, // key length 30 // Key data start 195, 101, 195, 207, 89, 214, 113, 235, 114, 218, 14, 122, @@ -178,28 +180,31 @@ func Test_decodeBranch(t *testing.T) { t.Parallel() testCases := map[string]struct { - reader io.Reader - header byte - branch *Node - errWrapped error - errMessage string + reader io.Reader + variant byte + partialKeyLength uint16 + branch *Node + errWrapped error + errMessage string }{ "key decoding error": { reader: bytes.NewBuffer([]byte{ // missing key data byte }), - header: 129, // node type 2 (branch without value) and key length 1 - errWrapped: ErrReadKeyData, - errMessage: "cannot decode key: cannot read key data: EOF", + variant: branchVariant.bits, + partialKeyLength: 1, + errWrapped: io.EOF, + errMessage: "cannot decode key: reading from reader: EOF", }, "children bitmap read error": { reader: bytes.NewBuffer([]byte{ 9, // key data // missing children bitmap 2 bytes }), - header: 129, // node type 2 (branch without value) and key length 1 - errWrapped: ErrReadChildrenBitmap, - errMessage: "cannot read children bitmap: EOF", + variant: branchVariant.bits, + partialKeyLength: 1, + errWrapped: ErrReadChildrenBitmap, + errMessage: "cannot read children bitmap: EOF", }, "children decoding error": { reader: bytes.NewBuffer([]byte{ @@ -207,21 +212,21 @@ func Test_decodeBranch(t *testing.T) { 0, 4, // 
children bitmap // missing children scale encoded data }), - header: 129, // node type 2 (branch without value) and key length 1 - errWrapped: ErrDecodeChildHash, - errMessage: "cannot decode child hash: at index 10: EOF", + variant: branchVariant.bits, + partialKeyLength: 1, + errWrapped: ErrDecodeChildHash, + errMessage: "cannot decode child hash: at index 10: EOF", }, - "success node type 2": { + "success for branch variant": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ - { - 9, // key data - 0, 4, // children bitmap - }, + {9}, // key data + {0, 4}, // children bitmap scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash }), ), - header: 129, // node type 2 (branch without value) and key length 1 + variant: branchVariant.bits, + partialKeyLength: 1, branch: &Node{ Key: []byte{9}, Children: padRightChildren([]*Node{ @@ -229,13 +234,14 @@ func Test_decodeBranch(t *testing.T) { nil, nil, nil, nil, nil, { HashDigest: []byte{1, 2, 3, 4, 5}, + Dirty: true, }, }), Dirty: true, Descendants: 1, }, }, - "value decoding error for node type 3": { + "value decoding error for branch with value variant": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ {9}, // key data @@ -243,11 +249,12 @@ func Test_decodeBranch(t *testing.T) { // missing encoded branch value }), ), - header: 193, // node type 3 (branch with value) and key length 1 - errWrapped: ErrDecodeValue, - errMessage: "cannot decode value: EOF", + variant: branchWithValueVariant.bits, + partialKeyLength: 1, + errWrapped: ErrDecodeValue, + errMessage: "cannot decode value: EOF", }, - "success node type 3": { + "success for branch with value": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ {9}, // key data @@ -256,7 +263,8 @@ func Test_decodeBranch(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash }), ), - header: 193, // node type 3 (branch with value) and key length 1 + variant: branchWithValueVariant.bits, + partialKeyLength: 1, branch: &Node{ Key: []byte{9}, Value: []byte{7, 8, 9}, 
@@ -265,6 +273,7 @@ func Test_decodeBranch(t *testing.T) { nil, nil, nil, nil, nil, { HashDigest: []byte{1, 2, 3, 4, 5}, + Dirty: true, }, }), Dirty: true, @@ -278,7 +287,8 @@ func Test_decodeBranch(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - branch, err := decodeBranch(testCase.reader, testCase.header) + branch, err := decodeBranch(testCase.reader, + testCase.variant, testCase.partialKeyLength) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { @@ -293,35 +303,39 @@ func Test_decodeLeaf(t *testing.T) { t.Parallel() testCases := map[string]struct { - reader io.Reader - header byte - leaf *Node - errWrapped error - errMessage string + reader io.Reader + variant byte + partialKeyLength uint16 + leaf *Node + errWrapped error + errMessage string }{ "key decoding error": { reader: bytes.NewBuffer([]byte{ // missing key data byte }), - header: 65, // node type 1 (leaf) and key length 1 - errWrapped: ErrReadKeyData, - errMessage: "cannot decode key: cannot read key data: EOF", + variant: leafVariant.bits, + partialKeyLength: 1, + errWrapped: io.EOF, + errMessage: "cannot decode key: reading from reader: EOF", }, "value decoding error": { reader: bytes.NewBuffer([]byte{ 9, // key data 255, 255, // bad value data }), - header: 65, // node type 1 (leaf) and key length 1 - errWrapped: ErrDecodeValue, - errMessage: "cannot decode value: could not decode invalid integer", + variant: leafVariant.bits, + partialKeyLength: 1, + errWrapped: ErrDecodeValue, + errMessage: "cannot decode value: could not decode invalid integer", }, "zero value": { reader: bytes.NewBuffer([]byte{ 9, // key data // missing value data }), - header: 65, // node type 1 (leaf) and key length 1 + variant: leafVariant.bits, + partialKeyLength: 1, leaf: &Node{ Key: []byte{9}, Dirty: true, @@ -334,7 +348,8 @@ func Test_decodeLeaf(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3, 4, 5), // value data }), ), - header: 65, // node type 1 (leaf) and key length 1 + variant: leafVariant.bits, 
+ partialKeyLength: 1, leaf: &Node{ Key: []byte{9}, Value: []byte{1, 2, 3, 4, 5}, @@ -348,7 +363,8 @@ func Test_decodeLeaf(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - leaf, err := decodeLeaf(testCase.reader, testCase.header) + leaf, err := decodeLeaf(testCase.reader, + testCase.partialKeyLength) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { diff --git a/internal/trie/node/encode.go b/internal/trie/node/encode.go index c7890e16a8..5bea739c0c 100644 --- a/internal/trie/node/encode.go +++ b/internal/trie/node/encode.go @@ -12,7 +12,9 @@ import ( ) // Encode encodes the node to the buffer given. -// The encoding format is documented in encode_doc.go. +// The encoding format is documented in the README.md +// of this package, and specified in the Polkadot spec at +// https://spec.polkadot.network/#sect-state-storage func (n *Node) Encode(buffer Buffer) (err error) { if !n.Dirty && n.Encoding != nil { _, err = buffer.Write(n.Encoding) diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go index 8c6757b4ef..c92a1a2751 100644 --- a/internal/trie/node/encode_decode_test.go +++ b/internal/trie/node/encode_decode_test.go @@ -93,6 +93,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { 14, 15, 16, 17, 10, 11, 12, 13, }, + Dirty: true, }, }), }, @@ -109,6 +110,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { 21, 186, 226, 204, 145, 132, 5, 39, 204, }, + Dirty: true, }, }), Dirty: true, @@ -127,12 +129,10 @@ func Test_Branch_Encode_Decode(t *testing.T) { err := testCase.branchToEncode.Encode(buffer) require.NoError(t, err) - oneBuffer := make([]byte, 1) - _, err = buffer.Read(oneBuffer) + variant, partialKeyLength, err := decodeHeader(buffer) require.NoError(t, err) - header := oneBuffer[0] - resultBranch, err := decodeBranch(buffer, header) + resultBranch, err := decodeBranch(buffer, variant, partialKeyLength) require.NoError(t, err) assert.Equal(t, testCase.branchDecoded, resultBranch) diff --git 
a/internal/trie/node/encode_doc.go b/internal/trie/node/encode_doc.go deleted file mode 100644 index 1a8b6a1c0a..0000000000 --- a/internal/trie/node/encode_doc.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package node - -//nolint:lll -// Modified Merkle-Patricia Trie -// See https://github.com/w3f/polkadot-spec/blob/master/runtime-environment-spec/polkadot_re_spec.pdf for the full specification. -// -// Note that for the following definitions, `|` denotes concatenation -// -// Branch encoding: -// NodeHeader | Extra partial key length | Partial Key | Value -// `NodeHeader` is a byte such that: -// most significant two bits of `NodeHeader`: 10 if branch w/o value, 11 if branch w/ value -// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) -// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length -// `Partial Key` is the branch's key -// `Value` is: Children Bitmap | SCALE Branch node Value | Hash(Enc(Child[i_1])) | Hash(Enc(Child[i_2])) | ... 
| Hash(Enc(Child[i_n])) -// -// Leaf encoding: -// NodeHeader | Extra partial key length | Partial Key | Value -// `NodeHeader` is a byte such that: -// most significant two bits of `NodeHeader`: 01 -// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) -// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length -// `Partial Key` is the leaf's key -// `Value` is the leaf's SCALE encoded value diff --git a/internal/trie/node/encode_test.go b/internal/trie/node/encode_test.go index e57c13902b..ea6a4fb47e 100644 --- a/internal/trie/node/encode_test.go +++ b/internal/trie/node/encode_test.go @@ -59,15 +59,16 @@ func Test_Node_Encode(t *testing.T) { }, "leaf header encoding error": { node: &Node{ - Key: make([]byte, 63+(1<<16)), + Key: make([]byte, 1), }, writes: []writeCall{ { - written: []byte{127}, + written: []byte{leafVariant.bits | 1}, + err: errTest, }, }, - wrappedErr: ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", + wrappedErr: errTest, + errMessage: "cannot encode header: test error", }, "leaf buffer write error for encoded key": { node: &Node{ @@ -75,10 +76,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{67}, + written: []byte{leafVariant.bits | 3}, // partial key length 3 }, { - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, err: errTest, }, }, @@ -92,10 +93,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{67}, + written: []byte{leafVariant.bits | 3}, // partial key length 3 }, { - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, }, { written: []byte{12, 4, 5, 6}, @@ -112,10 +113,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{67}, + written: []byte{leafVariant.bits | 3}, // partial key length 3 }, { - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, }, { written: 
[]byte{12, 4, 5, 6}, @@ -153,15 +154,16 @@ func Test_Node_Encode(t *testing.T) { "branch header encoding error": { node: &Node{ Children: make([]*Node, ChildrenCapacity), - Key: make([]byte, 63+(1<<16)), + Key: make([]byte, 1), }, writes: []writeCall{ { // header - written: []byte{191}, + written: []byte{branchVariant.bits | 1}, // partial key length 1 + err: errTest, }, }, - wrappedErr: ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", + wrappedErr: errTest, + errMessage: "cannot encode header: test error", }, "buffer write error for encoded key": { node: &Node{ @@ -171,10 +173,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{195}, + written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, err: errTest, }, }, @@ -192,10 +194,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{195}, + written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, }, { // children bitmap written: []byte{136, 0}, @@ -216,10 +218,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{195}, + written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, }, { // children bitmap written: []byte{136, 0}, @@ -243,10 +245,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{195}, + written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, }, { // children bitmap written: []byte{136, 0}, @@ -275,10 +277,10 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: 
[]byte{195}, + written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE - written: []byte{1, 35}, + written: []byte{0x01, 0x23}, }, { // children bitmap written: []byte{136, 0}, diff --git a/internal/trie/node/header.go b/internal/trie/node/header.go index 5177b6f10c..033c5e84e7 100644 --- a/internal/trie/node/header.go +++ b/internal/trie/node/header.go @@ -4,44 +4,151 @@ package node import ( + "errors" + "fmt" "io" ) -const ( - leafHeader byte = 1 // 01 - branchHeader byte = 2 // 10 - branchWithValueHeader byte = 3 // 11 -) - -const ( - keyLenOffset = 0x3f - nodeHeaderShift = 6 -) - // encodeHeader writes the encoded header for the node. func encodeHeader(node *Node, writer io.Writer) (err error) { - var header byte + partialKeyLength := len(node.Key) + if partialKeyLength > int(maxPartialKeyLength) { + panic(fmt.Sprintf("partial key length is too big: %d", partialKeyLength)) + } + + // Merge variant byte and partial key length together + var variant variant if node.Type() == Leaf { - header = leafHeader + variant = leafVariant } else if node.Value == nil { - header = branchHeader + variant = branchVariant } else { - header = branchWithValueHeader + variant = branchWithValueVariant } - header <<= nodeHeaderShift - if len(node.Key) < keyLenOffset { - header |= byte(len(node.Key)) - _, err = writer.Write([]byte{header}) + buffer := make([]byte, 1) + buffer[0] = variant.bits + partialKeyLengthMask := ^variant.mask + + if partialKeyLength < int(partialKeyLengthMask) { + // Partial key length fits in header byte + buffer[0] |= byte(partialKeyLength) + _, err = writer.Write(buffer) return err } - header = header | keyLenOffset - _, err = writer.Write([]byte{header}) + // Partial key length does not fit in header byte only + buffer[0] |= partialKeyLengthMask + partialKeyLength -= int(partialKeyLengthMask) + _, err = writer.Write(buffer) if err != nil { return err } - err = encodeKeyLength(len(node.Key), writer) - return err + for { + 
buffer[0] = 255 + if partialKeyLength < 255 { + buffer[0] = byte(partialKeyLength) + } + + _, err = writer.Write(buffer) + if err != nil { + return err + } + + partialKeyLength -= int(buffer[0]) + + if buffer[0] < 255 { + break + } + } + + return nil +} + +var ( + ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than 2^16") +) + +func decodeHeader(reader io.Reader) (variant byte, + partialKeyLength uint16, err error) { + buffer := make([]byte, 1) + _, err = reader.Read(buffer) + if err != nil { + return 0, 0, fmt.Errorf("reading header byte: %w", err) + } + + variant, partialKeyLengthHeader, partialKeyLengthHeaderMask, + err := decodeHeaderByte(buffer[0]) + if err != nil { + return 0, 0, fmt.Errorf("decoding header byte: %w", err) + } + + partialKeyLength = uint16(partialKeyLengthHeader) + if partialKeyLengthHeader < partialKeyLengthHeaderMask { + // partial key length is contained in the first byte. + return variant, partialKeyLength, nil + } + + // the partial key length header byte is equal to its maximum + // possible value; this means the partial key length is greater + // than this (0 to 2^6 - 1 = 63) maximum value, and we need to + // accumulate the next bytes from the reader to get the full + // partial key length. + // Specification: https://spec.polkadot.network/#defn-node-header + var previousKeyLength uint16 // used to track an eventual overflow + for { + _, err = reader.Read(buffer) + if err != nil { + return 0, 0, fmt.Errorf("reading key length: %w", err) + } + + previousKeyLength = partialKeyLength + partialKeyLength += uint16(buffer[0]) + + if partialKeyLength < previousKeyLength { + // the partial key can have a length up to 65535 which is the + // maximum uint16 value; therefore if we overflowed, we went over + // this maximum. 
+ overflowed := maxPartialKeyLength - previousKeyLength + partialKeyLength + return 0, 0, fmt.Errorf("%w: overflowed by %d", ErrPartialKeyTooBig, overflowed) + } + + if buffer[0] < 255 { + // the end of the partial key length has been reached. + return variant, partialKeyLength, nil + } + } +} + +var ErrVariantUnknown = errors.New("node variant is unknown") + +func decodeHeaderByte(header byte) (variantBits, + partialKeyLengthHeader, partialKeyLengthHeaderMask byte, err error) { + // variants is a slice of all variants sorted in ascending + // order by the number of bits each variant mask occupy + // in the header byte. + // See https://spec.polkadot.network/#defn-node-header + // Performance note: see `Benchmark_decodeHeaderByte`; + // running with a locally scoped slice is as fast as having + // it at global scope. + variants := []variant{ + leafVariant, // mask 1100_0000 + branchVariant, // mask 1100_0000 + branchWithValueVariant, // mask 1100_0000 + } + + for i := len(variants) - 1; i >= 0; i-- { + variantBits = header & variants[i].mask + if variantBits != variants[i].bits { + continue + } + + partialKeyLengthHeaderMask = ^variants[i].mask + partialKeyLengthHeader = header & partialKeyLengthHeaderMask + return variantBits, partialKeyLengthHeader, + partialKeyLengthHeaderMask, nil + } + + return 0, 0, 0, fmt.Errorf("%w: for header byte %08b", ErrVariantUnknown, header) } diff --git a/internal/trie/node/header_test.go b/internal/trie/node/header_test.go index 1ed826483a..8c572bbaf2 100644 --- a/internal/trie/node/header_test.go +++ b/internal/trie/node/header_test.go @@ -4,10 +4,14 @@ package node import ( + "bytes" + "io" + "math" "testing" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_encodeHeader(t *testing.T) { @@ -22,7 +26,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{0x80}}, + {written: 
[]byte{branchVariant.bits}}, }, }, "branch with value": { @@ -31,7 +35,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{0xc0}}, + {written: []byte{branchWithValueVariant.bits}}, }, }, "branch with key of length 30": { @@ -40,7 +44,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{0x9e}}, + {written: []byte{branchVariant.bits | 30}}, }, }, "branch with key of length 62": { @@ -49,7 +53,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{0xbe}}, + {written: []byte{branchVariant.bits | 62}}, }, }, "branch with key of length 63": { @@ -58,8 +62,9 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{0xbf}}, - {written: []byte{0x0}}, + {written: []byte{branchVariant.bits | 63}}, + {written: []byte{0x00}}, // trailing 0 to indicate the partial + // key length is done here. 
}, }, "branch with key of length 64": { @@ -68,28 +73,17 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{0xbf}}, - {written: []byte{0x1}}, + {written: []byte{branchVariant.bits | 63}}, + {written: []byte{0x01}}, }, }, - "branch with key too big": { - node: &Node{ - Key: make([]byte, 65535+63), - Children: make([]*Node, ChildrenCapacity), - }, - writes: []writeCall{ - {written: []byte{0xbf}}, - }, - errWrapped: ErrPartialKeyTooBig, - errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", - }, "branch with small key length write error": { node: &Node{ Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ { - written: []byte{0x80}, + written: []byte{branchVariant.bits}, err: errTest, }, }, @@ -98,12 +92,15 @@ func Test_encodeHeader(t *testing.T) { }, "branch with long key length write error": { node: &Node{ - Key: make([]byte, 64), + Key: make([]byte, int(^branchVariant.mask)+1), Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ { - written: []byte{0xbf}, + written: []byte{branchVariant.bits | ^branchVariant.mask}, + }, + { + written: []byte{0x01}, err: errTest, }, }, @@ -113,7 +110,7 @@ func Test_encodeHeader(t *testing.T) { "leaf with no key": { node: &Node{}, writes: []writeCall{ - {written: []byte{0x40}}, + {written: []byte{leafVariant.bits}}, }, }, "leaf with key of length 30": { @@ -121,7 +118,7 @@ func Test_encodeHeader(t *testing.T) { Key: make([]byte, 30), }, writes: []writeCall{ - {written: []byte{0x5e}}, + {written: []byte{leafVariant.bits | 30}}, }, }, "leaf with short key write error": { @@ -130,19 +127,19 @@ func Test_encodeHeader(t *testing.T) { }, writes: []writeCall{ { - written: []byte{0x5e}, + written: []byte{leafVariant.bits | 30}, err: errTest, }, }, errWrapped: errTest, - errMessage: errTest.Error(), + errMessage: "test error", }, "leaf with key of length 62": { node: &Node{ Key: make([]byte, 62), }, writes: 
[]writeCall{ - {written: []byte{0x7e}}, + {written: []byte{leafVariant.bits | 62}}, }, }, "leaf with key of length 63": { @@ -150,7 +147,7 @@ func Test_encodeHeader(t *testing.T) { Key: make([]byte, 63), }, writes: []writeCall{ - {written: []byte{0x7f}}, + {written: []byte{leafVariant.bits | 63}}, {written: []byte{0x0}}, }, }, @@ -159,7 +156,7 @@ func Test_encodeHeader(t *testing.T) { Key: make([]byte, 64), }, writes: []writeCall{ - {written: []byte{0x7f}}, + {written: []byte{leafVariant.bits | 63}}, {written: []byte{0x1}}, }, }, @@ -169,22 +166,32 @@ func Test_encodeHeader(t *testing.T) { }, writes: []writeCall{ { - written: []byte{0x7f}, + written: []byte{leafVariant.bits | 63}, err: errTest, }, }, errWrapped: errTest, - errMessage: errTest.Error(), + errMessage: "test error", }, - "leaf with key too big": { + "leaf with key length over 3 bytes": { node: &Node{ - Key: make([]byte, 65535+63), + Key: make([]byte, int(^leafVariant.mask)+0b1111_1111+0b0000_0001), }, writes: []writeCall{ - {written: []byte{0x7f}}, + {written: []byte{leafVariant.bits | ^leafVariant.mask}}, + {written: []byte{0b1111_1111}}, + {written: []byte{0b0000_0001}}, + }, + }, + "leaf with key length over 3 bytes and last byte zero": { + node: &Node{ + Key: make([]byte, int(^leafVariant.mask)+0b1111_1111), + }, + writes: []writeCall{ + {written: []byte{leafVariant.bits | ^leafVariant.mask}}, + {written: []byte{0b1111_1111}}, + {written: []byte{0x00}}, }, - errWrapped: ErrPartialKeyTooBig, - errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", }, } @@ -215,4 +222,211 @@ func Test_encodeHeader(t *testing.T) { } }) } + + t.Run("partial key length is too big", func(t *testing.T) { + t.Parallel() + + const keyLength = uint(maxPartialKeyLength) + 1 + node := &Node{ + Key: make([]byte, keyLength), + } + + assert.PanicsWithValue(t, "partial key length is too big: 65536", func() { + _ = encodeHeader(node, io.Discard) + }) + }) +} + +func Test_encodeHeader_At_Maximum(t 
*testing.T) { + t.Parallel() + + // Note: this test case cannot run with the + // mock writer since it's too slow, so we use + // an actual buffer. + + variant := leafVariant.bits + const partialKeyLengthHeaderMask = 0b0011_1111 + const keyLength = uint(maxPartialKeyLength) + extraKeyBytesNeeded := math.Ceil(float64(maxPartialKeyLength-partialKeyLengthHeaderMask) / 255.0) + expectedEncodingLength := 1 + int(extraKeyBytesNeeded) + + lengthLeft := maxPartialKeyLength + expectedBytes := make([]byte, expectedEncodingLength) + expectedBytes[0] = variant | partialKeyLengthHeaderMask + lengthLeft -= partialKeyLengthHeaderMask + for i := 1; i < len(expectedBytes)-1; i++ { + expectedBytes[i] = 255 + lengthLeft -= 255 + } + expectedBytes[len(expectedBytes)-1] = byte(lengthLeft) + + buffer := bytes.NewBuffer(nil) + buffer.Grow(expectedEncodingLength) + + node := &Node{ + Key: make([]byte, keyLength), + } + + err := encodeHeader(node, buffer) + + require.NoError(t, err) + assert.Equal(t, expectedBytes, buffer.Bytes()) +} + +func Test_decodeHeader(t *testing.T) { + testCases := map[string]struct { + reads []readCall + variant byte + partialKeyLength uint16 + errWrapped error + errMessage string + }{ + "first byte read error": { + reads: []readCall{ + {buffArgCap: 1, err: errTest}, + }, + errWrapped: errTest, + errMessage: "reading header byte: test error", + }, + "header byte decoding error": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{0b0011_1110}}, + }, + errWrapped: ErrVariantUnknown, + errMessage: "decoding header byte: node variant is unknown: for header byte 00111110", + }, + "partial key length contained in first byte": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1110}}, + }, + variant: leafVariant.bits, + partialKeyLength: uint16(0b0011_1110), + }, + "long partial key length and second byte read error": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1111}}, + {buffArgCap: 1, err: errTest}, + 
}, + errWrapped: errTest, + errMessage: "reading key length: test error", + }, + "partial key length spread on multiple bytes": { + reads: []readCall{ + {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1111}}, + {buffArgCap: 1, read: []byte{0b1111_1111}}, + {buffArgCap: 1, read: []byte{0b1111_0000}}, + }, + variant: leafVariant.bits, + partialKeyLength: uint16(0b0011_1111 + 0b1111_1111 + 0b1111_0000), + }, + "partial key length too long": { + reads: repeatReadCall(readCall{ + buffArgCap: 1, + read: []byte{0b1111_1111}, + }, 258), + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than 2^16: overflowed by 254", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + reader := NewMockReader(ctrl) + var previousCall *gomock.Call + for _, readCall := range testCase.reads { + readCall := readCall // required variable pinning + byteSliceCapMatcher := newByteSliceCapMatcher(readCall.buffArgCap) + call := reader.EXPECT().Read(byteSliceCapMatcher). 
+ DoAndReturn(func(b []byte) (n int, err error) { + copy(b, readCall.read) + return readCall.n, readCall.err + }) + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + variant, partialKeyLength, err := decodeHeader(reader) + + assert.Equal(t, testCase.variant, variant) + assert.Equal(t, int(testCase.partialKeyLength), int(partialKeyLength)) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + +func Test_decodeHeaderByte(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + header byte + variantBits byte + partialKeyLengthHeader byte + partialKeyLengthHeaderMask byte + errWrapped error + errMessage string + }{ + "branch with value header": { + header: 0b1110_1001, + variantBits: 0b1100_0000, + partialKeyLengthHeader: 0b0010_1001, + partialKeyLengthHeaderMask: 0b0011_1111, + }, + "branch header": { + header: 0b1010_1001, + variantBits: 0b1000_0000, + partialKeyLengthHeader: 0b0010_1001, + partialKeyLengthHeaderMask: 0b0011_1111, + }, + "leaf header": { + header: 0b0110_1001, + variantBits: 0b0100_0000, + partialKeyLengthHeader: 0b0010_1001, + partialKeyLengthHeaderMask: 0b0011_1111, + }, + "unknown variant header": { + header: 0b0000_0000, + errWrapped: ErrVariantUnknown, + errMessage: "node variant is unknown: for header byte 00000000", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + variantBits, partialKeyLengthHeader, + partialKeyLengthHeaderMask, err := decodeHeaderByte(testCase.header) + + assert.Equal(t, testCase.variantBits, variantBits) + assert.Equal(t, testCase.partialKeyLengthHeader, partialKeyLengthHeader) + assert.Equal(t, testCase.partialKeyLengthHeaderMask, partialKeyLengthHeaderMask) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} + 
+func Benchmark_decodeHeaderByte(b *testing.B) { + // With globally scoped variants slice: + // 3.453 ns/op 0 B/op 0 allocs/op + // With locally scoped variants slice: + // 3.441 ns/op 0 B/op 0 allocs/op + header := leafVariant.bits | 0b0000_0001 + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _, _ = decodeHeaderByte(header) + } } diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go index 3b450513bb..343a5d747d 100644 --- a/internal/trie/node/key.go +++ b/internal/trie/node/key.go @@ -4,92 +4,31 @@ package node import ( - "bytes" "errors" "fmt" "io" "github.com/ChainSafe/gossamer/internal/trie/codec" - "github.com/ChainSafe/gossamer/internal/trie/pools" ) -const maxPartialKeySize = ^uint16(0) +const maxPartialKeyLength = ^uint16(0) -var ( - ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16") - ErrReadKeyLength = errors.New("cannot read key length") - ErrReadKeyData = errors.New("cannot read key data") -) - -// encodeKeyLength encodes the key length. -func encodeKeyLength(keyLength int, writer io.Writer) (err error) { - keyLength -= 63 - - if keyLength >= int(maxPartialKeySize) { - return fmt.Errorf("%w: %d", - ErrPartialKeyTooBig, keyLength) - } - - for i := uint16(0); i < maxPartialKeySize; i++ { - if keyLength < 255 { - _, err = writer.Write([]byte{byte(keyLength)}) - if err != nil { - return err - } - break - } - _, err = writer.Write([]byte{255}) - if err != nil { - return err - } - - keyLength -= 255 - } - - return nil -} +var ErrReaderMismatchCount = errors.New("read unexpected number of bytes from reader") // decodeKey decodes a key from a reader. 
-func decodeKey(reader io.Reader, keyLengthByte byte) (b []byte, err error) { - keyLength := int(keyLengthByte) - - if keyLengthByte == keyLenOffset { - // partial key longer than 63, read next bytes for rest of pk len - buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) - defer pools.SingleByteBuffers.Put(buffer) - oneByteBuf := buffer.Bytes() - for { - _, err = reader.Read(oneByteBuf) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrReadKeyLength, err) - } - nextKeyLen := oneByteBuf[0] - - keyLength += int(nextKeyLen) - - if nextKeyLen < 0xff { - break - } - - if keyLength >= int(maxPartialKeySize) { - return nil, fmt.Errorf("%w: %d", - ErrPartialKeyTooBig, keyLength) - } - } - } - - if keyLength == 0 { +func decodeKey(reader io.Reader, partialKeyLength uint16) (b []byte, err error) { + if partialKeyLength == 0 { return []byte{}, nil } - key := make([]byte, keyLength/2+keyLength%2) + key := make([]byte, partialKeyLength/2+partialKeyLength%2) n, err := reader.Read(key) if err != nil { - return nil, fmt.Errorf("%w: %s", ErrReadKeyData, err) + return nil, fmt.Errorf("reading from reader: %w", err) } else if n != len(key) { - return nil, fmt.Errorf("%w: read %d bytes instead of %d", - ErrReadKeyData, n, len(key)) + return nil, fmt.Errorf("%w: read %d bytes instead of expected %d bytes", + ErrReaderMismatchCount, n, len(key)) } - return codec.KeyLEToNibbles(key)[keyLength%2:], nil + return codec.KeyLEToNibbles(key)[partialKeyLength%2:], nil } diff --git a/internal/trie/node/key_test.go b/internal/trie/node/key_test.go index 2e21825cce..930a97c656 100644 --- a/internal/trie/node/key_test.go +++ b/internal/trie/node/key_test.go @@ -4,13 +4,11 @@ package node import ( - "bytes" "fmt" "testing" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func repeatBytes(n int, b byte) (slice []byte) { @@ -21,129 +19,6 @@ func repeatBytes(n int, b byte) (slice []byte) { return slice } -func 
Test_encodeKeyLength(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - keyLength int - writes []writeCall - errWrapped error - errMessage string - }{ - "length equal to maximum": { - keyLength: int(maxPartialKeySize) + 63, - errWrapped: ErrPartialKeyTooBig, - errMessage: "partial key length cannot be " + - "larger than or equal to 2^16: 65535", - }, - "zero length": { - writes: []writeCall{ - { - written: []byte{0xc1}, - }, - }, - }, - "one length": { - keyLength: 1, - writes: []writeCall{ - { - written: []byte{0xc2}, - }, - }, - }, - "error at single byte write": { - keyLength: 1, - writes: []writeCall{ - { - written: []byte{0xc2}, - err: errTest, - }, - }, - errWrapped: errTest, - errMessage: errTest.Error(), - }, - "error at first byte write": { - keyLength: 255 + 100 + 63, - writes: []writeCall{ - { - written: []byte{255}, - err: errTest, - }, - }, - errWrapped: errTest, - errMessage: errTest.Error(), - }, - "error at last byte write": { - keyLength: 255 + 100 + 63, - writes: []writeCall{ - { - written: []byte{255}, - }, - { - written: []byte{100}, - err: errTest, - }, - }, - errWrapped: errTest, - errMessage: errTest.Error(), - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - writer := NewMockWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := writer.EXPECT(). - Write(write.written). 
- Return(write.n, write.err) - - if write.err != nil { - break - } else if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeKeyLength(testCase.keyLength, writer) - - assert.ErrorIs(t, err, testCase.errWrapped) - if testCase.errWrapped != nil { - assert.EqualError(t, err, testCase.errMessage) - } - }) - } - - t.Run("length at maximum", func(t *testing.T) { - t.Parallel() - - // Note: this test case cannot run with the - // mock writer since it's too slow, so we use - // an actual buffer. - - const keyLength = int(maxPartialKeySize) + 62 - const expectedEncodingLength = 257 - expectedBytes := make([]byte, expectedEncodingLength) - for i := 0; i < len(expectedBytes)-1; i++ { - expectedBytes[i] = 255 - } - expectedBytes[len(expectedBytes)-1] = 254 - - buffer := bytes.NewBuffer(nil) - buffer.Grow(expectedEncodingLength) - - err := encodeKeyLength(keyLength, buffer) - - require.NoError(t, err) - assert.Equal(t, expectedBytes, buffer.Bytes()) - }) -} - //go:generate mockgen -destination=reader_mock_test.go -package $GOPACKAGE io Reader type readCall struct { @@ -153,20 +28,12 @@ type readCall struct { err error } -func repeatReadCalls(rc readCall, length int) (readCalls []readCall) { - readCalls = make([]readCall, length) - for i := range readCalls { - readCalls[i] = readCall{ - buffArgCap: rc.buffArgCap, - n: rc.n, - err: rc.err, - } - if rc.read != nil { - readCalls[i].read = make([]byte, len(rc.read)) - copy(readCalls[i].read, rc.read) - } +func repeatReadCall(base readCall, n int) (calls []readCall) { + calls = make([]readCall, n) + for i := range calls { + calls[i] = base } - return readCalls + return calls } var _ gomock.Matcher = (*byteSliceCapMatcher)(nil) @@ -184,7 +51,7 @@ func (b *byteSliceCapMatcher) Matches(x interface{}) bool { } func (b *byteSliceCapMatcher) String() string { - return fmt.Sprintf("capacity of slice is not the expected capacity %d", b.capacity) + return fmt.Sprintf("slice with capacity %d", 
b.capacity) } func newByteSliceCapMatcher(capacity int) *byteSliceCapMatcher { @@ -197,45 +64,45 @@ func Test_decodeKey(t *testing.T) { t.Parallel() testCases := map[string]struct { - reads []readCall - keyLength byte - b []byte - errWrapped error - errMessage string + reads []readCall + partialKeyLength uint16 + b []byte + errWrapped error + errMessage string }{ "zero key length": { - b: []byte{}, + partialKeyLength: 0, + b: []byte{}, }, "short key length": { reads: []readCall{ {buffArgCap: 3, read: []byte{1, 2, 3}, n: 3}, }, - keyLength: 5, - b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, + partialKeyLength: 5, + b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, }, "key read error": { reads: []readCall{ {buffArgCap: 3, err: errTest}, }, - keyLength: 5, - errWrapped: ErrReadKeyData, - errMessage: "cannot read key data: test error", + partialKeyLength: 5, + errWrapped: errTest, + errMessage: "reading from reader: test error", }, "key read bytes count mismatch": { reads: []readCall{ {buffArgCap: 3, n: 2}, }, - keyLength: 5, - errWrapped: ErrReadKeyData, - errMessage: "cannot read key data: read 2 bytes instead of 3", + partialKeyLength: 5, + errWrapped: ErrReaderMismatchCount, + errMessage: "read unexpected number of bytes from reader: read 2 bytes instead of expected 3 bytes", }, "long key length": { reads: []readCall{ - {buffArgCap: 1, read: []byte{6}, n: 1}, // key length {buffArgCap: 35, read: repeatBytes(35, 7), n: 35}, // key data }, - keyLength: 0x3f, + partialKeyLength: 70, b: []byte{ 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, @@ -245,20 +112,6 @@ func Test_decodeKey(t *testing.T) { 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, }, - "key length read error": { - reads: []readCall{ - {buffArgCap: 1, err: errTest}, - }, - keyLength: 0x3f, - errWrapped: ErrReadKeyLength, - errMessage: "cannot read key length: test error", - }, - "key length too big": { - reads: 
repeatReadCalls(readCall{buffArgCap: 1, read: []byte{0xff}, n: 1}, 257), - keyLength: 0x3f, - errWrapped: ErrPartialKeyTooBig, - errMessage: "partial key length cannot be larger than or equal to 2^16: 65598", - }, } for name, testCase := range testCases { @@ -270,6 +123,7 @@ func Test_decodeKey(t *testing.T) { reader := NewMockReader(ctrl) var previousCall *gomock.Call for _, readCall := range testCase.reads { + readCall := readCall // required variable pinning byteSliceCapMatcher := newByteSliceCapMatcher(readCall.buffArgCap) call := reader.EXPECT().Read(byteSliceCapMatcher). DoAndReturn(func(b []byte) (n int, err error) { @@ -282,7 +136,7 @@ func Test_decodeKey(t *testing.T) { previousCall = call } - b, err := decodeKey(reader, testCase.keyLength) + b, err := decodeKey(reader, testCase.partialKeyLength) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index 493ca1de91..a40cf31fd7 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -1,6 +1,8 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only +// Package node defines the `Node` structure with methods +// to be used in the modified Merkle-Patricia Radix-16 trie. 
package node import ( diff --git a/internal/trie/node/variants.go b/internal/trie/node/variants.go new file mode 100644 index 0000000000..2c75c44904 --- /dev/null +++ b/internal/trie/node/variants.go @@ -0,0 +1,26 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +type variant struct { + bits byte + mask byte +} + +// Node variants +// See https://spec.polkadot.network/#defn-node-header +var ( + leafVariant = variant{ // leaf 01 + bits: 0b0100_0000, + mask: 0b1100_0000, + } + branchVariant = variant{ // branch 10 + bits: 0b1000_0000, + mask: 0b1100_0000, + } + branchWithValueVariant = variant{ // branch 11 + bits: 0b1100_0000, + mask: 0b1100_0000, + } +) diff --git a/internal/trie/pools/pools.go b/internal/trie/pools/pools.go index 855232ef44..1bfe8f5a83 100644 --- a/internal/trie/pools/pools.go +++ b/internal/trie/pools/pools.go @@ -10,15 +10,6 @@ import ( "golang.org/x/crypto/blake2b" ) -// SingleByteBuffers is a sync pool of buffers of capacity 1. -var SingleByteBuffers = &sync.Pool{ - New: func() interface{} { - const bufferLength = 1 - b := make([]byte, bufferLength) - return bytes.NewBuffer(b) - }, -} - // DigestBuffers is a sync pool of buffers of capacity 32. var DigestBuffers = &sync.Pool{ New: func() interface{} {