Skip to content

Commit

Permalink
chore(lib/trie): lib/trie/recorder sub-package (#2082)
Browse files Browse the repository at this point in the history
* `lib/trie/recorder` subpackage

* return an error on a call to Next() with no node

* remove recorder `IsEmpty` method

* Recorder `GetNodes()`
  • Loading branch information
qdm12 committed Dec 9, 2021
1 parent 8609bcf commit b21c447
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 45 deletions.
13 changes: 10 additions & 3 deletions lib/trie/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,27 @@ import (

"github.com/ChainSafe/gossamer/lib/trie/branch"
"github.com/ChainSafe/gossamer/lib/trie/node"
"github.com/ChainSafe/gossamer/lib/trie/record"
)

var _ recorder = (*record.Recorder)(nil)

type recorder interface {
Record(hash, rawData []byte)
}

// findAndRecord search for a desired key recording all the nodes in the path including the desired node
func findAndRecord(t *Trie, key []byte, recorder *recorder) error {
func findAndRecord(t *Trie, key []byte, recorder recorder) error {
return find(t.root, key, recorder)
}

func find(parent node.Node, key []byte, recorder *recorder) error {
func find(parent node.Node, key []byte, recorder recorder) error {
enc, hash, err := parent.EncodeAndHash()
if err != nil {
return err
}

recorder.record(hash, enc)
recorder.Record(hash, enc)

b, ok := parent.(*branch.Branch)
if !ok {
Expand Down
13 changes: 5 additions & 8 deletions lib/trie/proof.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Copyright 2021 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package trie

import (
Expand All @@ -12,6 +9,7 @@ import (
"github.com/ChainSafe/chaindb"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie/decode"
"github.com/ChainSafe/gossamer/lib/trie/record"
)

var (
Expand Down Expand Up @@ -43,17 +41,16 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e
for _, k := range keys {
nk := decode.KeyLEToNibbles(k)

recorder := new(recorder)
recorder := record.NewRecorder()
err := findAndRecord(proofTrie, nk, recorder)
if err != nil {
return nil, err
}

for !recorder.isEmpty() {
recNode := recorder.next()
nodeHashHex := common.BytesToHex(recNode.hash)
for _, recNode := range recorder.GetNodes() {
nodeHashHex := common.BytesToHex(recNode.Hash)
if _, ok := trackedProofs[nodeHashHex]; !ok {
trackedProofs[nodeHashHex] = recNode.rawData
trackedProofs[nodeHashHex] = recNode.RawData
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions lib/trie/record/node.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package record

// Node represents a record of a visited node
type Node struct {
RawData []byte
Hash []byte
}
27 changes: 27 additions & 0 deletions lib/trie/record/recorder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2021 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package record

// Recorder records the list of nodes found by Lookup.Find
type Recorder struct {
nodes []Node
}

// NewRecorder creates a new recorder.
func NewRecorder() *Recorder {
return &Recorder{}
}

// Record appends a node to the list of visited nodes.
func (r *Recorder) Record(hash, rawData []byte) {
r.nodes = append(r.nodes, Node{RawData: rawData, Hash: hash})
}

// GetNodes returns all the nodes recorded.
// Note it does not copy its slice of nodes.
// It's fine to not copy them since the recorder
// is not used again after a call to GetNodes()
func (r *Recorder) GetNodes() (nodes []Node) {
return r.nodes
}
115 changes: 115 additions & 0 deletions lib/trie/record/recorder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package record

import (
"testing"

"github.com/stretchr/testify/assert"
)

func Test_NewRecorder(t *testing.T) {
t.Parallel()

expected := &Recorder{}

recorder := NewRecorder()

assert.Equal(t, expected, recorder)
}

func Test_Recorder_Record(t *testing.T) {
testCases := map[string]struct {
recorder *Recorder
hash []byte
rawData []byte
expectedRecorder *Recorder
}{
"nil data": {
recorder: &Recorder{},
expectedRecorder: &Recorder{
nodes: []Node{
{},
},
},
},
"insert in empty recorder": {
recorder: &Recorder{},
hash: []byte{1, 2},
rawData: []byte{3, 4},
expectedRecorder: &Recorder{
nodes: []Node{
{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
},
},
},
"insert in non-empty recorder": {
recorder: &Recorder{
nodes: []Node{
{Hash: []byte{5, 6}, RawData: []byte{7, 8}},
},
},
hash: []byte{1, 2},
rawData: []byte{3, 4},
expectedRecorder: &Recorder{
nodes: []Node{
{Hash: []byte{5, 6}, RawData: []byte{7, 8}},
{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
},
},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

testCase.recorder.Record(testCase.hash, testCase.rawData)

assert.Equal(t, testCase.expectedRecorder, testCase.recorder)
})
}
}

func Test_Recorder_GetNodes(t *testing.T) {
testCases := map[string]struct {
recorder *Recorder
nodes []Node
}{
"no node": {
recorder: &Recorder{},
},
"get single node from recorder": {
recorder: &Recorder{
nodes: []Node{
{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
},
},
nodes: []Node{{Hash: []byte{1, 2}, RawData: []byte{3, 4}}},
},
"get node from multiple nodes in recorder": {
recorder: &Recorder{
nodes: []Node{
{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
{Hash: []byte{5, 6}, RawData: []byte{7, 8}},
{Hash: []byte{9, 6}, RawData: []byte{7, 8}},
},
},
nodes: []Node{
{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
{Hash: []byte{5, 6}, RawData: []byte{7, 8}},
{Hash: []byte{9, 6}, RawData: []byte{7, 8}},
},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

nodes := testCase.recorder.GetNodes()

assert.Equal(t, testCase.nodes, nodes)
})
}
}
34 changes: 0 additions & 34 deletions lib/trie/recorder.go

This file was deleted.

0 comments on commit b21c447

Please sign in to comment.