Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jun 30, 2022
1 parent 0390570 commit 0bdb189
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 180 deletions.
10 changes: 10 additions & 0 deletions internal/trie/node/children.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,13 @@ func (n *Node) NumChildren() (count int) {
}
return count
}

// HasChild returns true if the node has at least one child.
func (n *Node) HasChild() (has bool) {
for _, child := range n.Children {
if child != nil {
return true
}
}
return false
}
39 changes: 39 additions & 0 deletions internal/trie/node/children_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,42 @@ func Test_Node_NumChildren(t *testing.T) {
})
}
}

func Test_Node_HasChild(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
node Node
has bool
}{
"no child": {},
"one child at index 0": {
node: Node{
Children: []*Node{
{},
},
},
has: true,
},
"one child at index 1": {
node: Node{
Children: []*Node{
nil,
{},
},
},
has: true,
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

has := testCase.node.HasChild()

assert.Equal(t, testCase.has, has)
})
}
}
12 changes: 2 additions & 10 deletions lib/trie/proof/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,8 @@ func Test_walk(t *testing.T) {
errWrapped: ErrKeyNotFound,
errMessage: "key not found",
},
"parent encode and hash error": {
parent: &node.Node{
Key: make([]byte, int(^uint16(0))+63),
Value: []byte{1},
},
errWrapped: node.ErrPartialKeyTooBig,
errMessage: "encode node: " +
"cannot encode header: partial key length cannot " +
"be larger than or equal to 2^16: 65535",
},
// The parent encode error cannot be triggered here
// since it can only be caused by a buffer.Write error.
"parent leaf and empty full key": {
parent: &node.Node{
Key: []byte{1, 2},
Expand Down
68 changes: 67 additions & 1 deletion lib/trie/proof/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ package proof

import (
"bytes"
"math/rand"
"testing"

"github.com/ChainSafe/gossamer/internal/trie/node"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/stretchr/testify/require"
)

Expand All @@ -29,8 +31,72 @@ func encodeNode(t *testing.T, node node.Node) (encoded []byte) {
func blake2bNode(t *testing.T, node node.Node) (digest []byte) {
t.Helper()
encoding := encodeNode(t, node)
digestHash, err := common.Blake2bHash(encoding)
return blake2b(t, encoding)
}

func scaleEncode(t *testing.T, data []byte) (encoded []byte) {
t.Helper()
encoded, err := scale.Marshal(data)
require.NoError(t, err)
return encoded
}

func blake2b(t *testing.T, data []byte) (digest []byte) {
t.Helper()
digestHash, err := common.Blake2bHash(data)
require.NoError(t, err)
digest = digestHash[:]
return digest
}

func concatBytes(slices [][]byte) (concatenated []byte) {
for _, slice := range slices {
concatenated = append(concatenated, slice...)
}
return concatenated
}

// generateBytes generates a pseudo random byte slice
// of the given length. It uses `0` as its seed so
// calling it multiple times will generate the same
// byte slice. This is designed as such in order to have
// deterministic unit tests.
func generateBytes(t *testing.T, length uint) (bytes []byte) {
t.Helper()
generator := rand.New(rand.NewSource(0))
bytes = make([]byte, length)
_, err := generator.Read(bytes)
require.NoError(t, err)
return bytes
}

// getBadNodeEncoding returns a particular bad node encoding of 33 bytes.
func getBadNodeEncoding() (badEncoding []byte) {
return []byte{
0x1, 0x94, 0xfd, 0xc2, 0xfa, 0x2f, 0xfc, 0xc0, 0x41, 0xd3,
0xff, 0x12, 0x4, 0x5b, 0x73, 0xc8, 0x6e, 0x4f, 0xf9, 0x5f,
0xf6, 0x62, 0xa5, 0xee, 0xe8, 0x2a, 0xbd, 0xf4, 0x4a, 0x2d,
0xb, 0x75, 0xfb}
}

func Test_getBadNodeEncoding(t *testing.T) {
t.Parallel()

badEncoding := getBadNodeEncoding()
_, err := node.Decode(bytes.NewBuffer(badEncoding))
require.Error(t, err)
}

func assertLongEncoding(t *testing.T, node node.Node) {
t.Helper()

encoding := encodeNode(t, node)
require.Greater(t, len(encoding), 32)
}

func assertShortEncoding(t *testing.T, node node.Node) {
t.Helper()

encoding := encodeNode(t, node)
require.LessOrEqual(t, len(encoding), 32)
}
123 changes: 71 additions & 52 deletions lib/trie/proof/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,76 +57,72 @@ func buildTrie(encodedProofNodes [][]byte, rootHash []byte) (t *trie.Trie, err e
ErrEmptyProof, rootHash)
}

merkleValueToNode := make(map[string]*node.Node, len(encodedProofNodes))
merkleValueToEncoding := make(map[string][]byte, len(encodedProofNodes))

// This loop finds the root node and decodes it.
// The other nodes have their Merkle value (blake2b digest or the encoding itself)
// inserted into a map from merkle value to encoding. They are only decoded
// later if the root or one of its descendant node reference their Merkle value.
var root *node.Node
for i, encodedProofNode := range encodedProofNodes {
decodedNode, err := node.Decode(bytes.NewReader(encodedProofNode))
if err != nil {
return nil, fmt.Errorf("decoding node at index %d: %w (node encoded is 0x%x)",
i, err, encodedProofNode)
}

decodedNode.Encoding = encodedProofNode
// We compute the Merkle value of nodes treating them all
// as non-root nodes, meaning nodes with encoding smaller
// than 33 bytes will have their Merkle value set as their
// encoding. The Blake2b hash of the encoding is computed
// later if needed to compare with the root hash given to find
// which node is the root node.
const isRoot = false
decodedNode.HashDigest, err = node.MerkleValue(encodedProofNode, isRoot)
if err != nil {
return nil, fmt.Errorf("merkle value of node at index %d: %w", i, err)
}

proofHash := common.BytesToHex(decodedNode.HashDigest)
merkleValueToNode[proofHash] = decodedNode

if root != nil {
// Root node already found in proof
continue
}

possibleRootMerkleValue := decodedNode.HashDigest
if len(possibleRootMerkleValue) <= 32 {
// If the root merkle value is smaller than 33 bytes, it means
// it is the encoding of the node. However, the root node merkle
// value is always the blake2b digest of the node, and not its own
// encoding. Therefore, in this case we force the computation of the
// blake2b digest of the node to check if it matches the root hash given.
const isRoot = true
possibleRootMerkleValue, err = node.MerkleValue(encodedProofNode, isRoot)
for _, encodedProofNode := range encodedProofNodes {
var digest []byte
if root == nil {
// root node not found yet
digestHash, err := common.Blake2bHash(encodedProofNode)
if err != nil {
return nil, fmt.Errorf("merkle value of possible root node: %w", err)
return nil, fmt.Errorf("blake2b hash: %w", err)
}
digest = digestHash[:]

if bytes.Equal(digest, rootHash) {
root, err = node.Decode(bytes.NewReader(encodedProofNode))
if err != nil {
return nil, fmt.Errorf("decoding root node: %w", err)
}
continue // no need to add root to map of hash to encoding
}
}

if bytes.Equal(rootHash, possibleRootMerkleValue) {
decodedNode.HashDigest = rootHash
root = decodedNode
var merkleValue []byte
if len(encodedProofNode) <= 32 {
merkleValue = encodedProofNode
} else {
if digest == nil {
digestHash, err := common.Blake2bHash(encodedProofNode)
if err != nil {
return nil, fmt.Errorf("blake2b hash: %w", err)
}
digest = digestHash[:]
}
merkleValue = digest
}

merkleValueToEncoding[string(merkleValue)] = encodedProofNode
}

if root == nil {
proofMerkleValues := make([]string, 0, len(merkleValueToNode))
for merkleValue := range merkleValueToNode {
proofMerkleValues = append(proofMerkleValues, merkleValue)
proofMerkleValues := make([]string, 0, len(merkleValueToEncoding))
for merkleValueString := range merkleValueToEncoding {
merkleValueHex := common.BytesToHex([]byte(merkleValueString))
proofMerkleValues = append(proofMerkleValues, merkleValueHex)
}
return nil, fmt.Errorf("%w: for Merkle root hash 0x%x in proof Merkle value(s) %s",
ErrRootNodeNotFound, rootHash, strings.Join(proofMerkleValues, ", "))
}

loadProof(merkleValueToNode, root)
err = loadProof(merkleValueToEncoding, root)
if err != nil {
return nil, fmt.Errorf("loading proof: %w", err)
}

return trie.NewTrie(root), nil
}

// loadProof is a recursive function that will create all the trie paths based
// on the map from node hash to node starting at the root.
func loadProof(merkleValueToNode map[string]*node.Node, n *node.Node) {
func loadProof(merkleValueToEncoding map[string][]byte, n *node.Node) (err error) {
if n.Type() != node.Branch {
return
return nil
}

branch := n
Expand All @@ -135,15 +131,38 @@ func loadProof(merkleValueToNode map[string]*node.Node, n *node.Node) {
continue
}

merkleValueHex := common.BytesToHex(child.HashDigest)
node, ok := merkleValueToNode[merkleValueHex]
merkleValue := child.HashDigest
encoding, ok := merkleValueToEncoding[string(merkleValue)]
if !ok {
inlinedChild := len(child.Value) > 0 || child.HasChild()
if !inlinedChild {
// hash not found and the child is not inlined,
// so clear the child from the branch.
branch.Descendants -= 1 + child.Descendants
branch.Children[i] = nil
if !branch.HasChild() {
// Convert branch to a leaf if all its children are nil.
branch.Children = nil
}
}
continue
}

branch.Children[i] = node
loadProof(merkleValueToNode, node)
child, err := node.Decode(bytes.NewReader(encoding))
if err != nil {
return fmt.Errorf("decoding child node for Merkle value 0x%x: %w",
merkleValue, err)
}

branch.Children[i] = child
branch.Descendants += child.Descendants
err = loadProof(merkleValueToEncoding, child)
if err != nil {
return err // do not wrap error since this is recursive
}
}

return nil
}

func bytesToString(b []byte) (s string) {
Expand Down
Loading

0 comments on commit 0bdb189

Please sign in to comment.