diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 4df79c6c338..e51596dc69c 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -8,20 +8,27 @@ import ( "github.com/ChainSafe/gossamer/lib/trie/branch" "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/lib/trie/record" ) +var _ recorder = (*record.Recorder)(nil) + +type recorder interface { + Record(hash, rawData []byte) +} + // findAndRecord search for a desired key recording all the nodes in the path including the desired node -func findAndRecord(t *Trie, key []byte, recorder *recorder) error { +func findAndRecord(t *Trie, key []byte, recorder recorder) error { return find(t.root, key, recorder) } -func find(parent node.Node, key []byte, recorder *recorder) error { +func find(parent node.Node, key []byte, recorder recorder) error { enc, hash, err := parent.EncodeAndHash() if err != nil { return err } - recorder.record(hash, enc) + recorder.Record(hash, enc) b, ok := parent.(*branch.Branch) if !ok { diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 12c1d382a0a..094fe48f7e9 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -1,6 +1,3 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - package trie import ( @@ -12,6 +9,7 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/record" ) var ( @@ -43,17 +41,16 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e for _, k := range keys { nk := decode.KeyLEToNibbles(k) - recorder := new(recorder) + recorder := record.NewRecorder() err := findAndRecord(proofTrie, nk, recorder) if err != nil { return nil, err } - for !recorder.isEmpty() { - recNode := recorder.next() - nodeHashHex := common.BytesToHex(recNode.hash) + for _, recNode := range recorder.GetNodes() { + nodeHashHex := common.BytesToHex(recNode.Hash) if _, ok := trackedProofs[nodeHashHex]; !ok { - trackedProofs[nodeHashHex] = recNode.rawData + trackedProofs[nodeHashHex] = recNode.RawData } } } diff --git a/lib/trie/record/node.go b/lib/trie/record/node.go new file mode 100644 index 00000000000..eb3299e9bce --- /dev/null +++ b/lib/trie/record/node.go @@ -0,0 +1,7 @@ +package record + +// Node represents a record of a visited node +type Node struct { + RawData []byte + Hash []byte +} diff --git a/lib/trie/record/recorder.go b/lib/trie/record/recorder.go new file mode 100644 index 00000000000..130b434338a --- /dev/null +++ b/lib/trie/record/recorder.go @@ -0,0 +1,27 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package record + +// Recorder records the list of nodes found by Lookup.Find +type Recorder struct { + nodes []Node +} + +// NewRecorder creates a new recorder. +func NewRecorder() *Recorder { + return &Recorder{} +} + +// Record appends a node to the list of visited nodes. +func (r *Recorder) Record(hash, rawData []byte) { + r.nodes = append(r.nodes, Node{RawData: rawData, Hash: hash}) +} + +// GetNodes returns all the nodes recorded. +// Note it does not copy its slice of nodes. +// It's fine to not copy them since the recorder +// is not used again after a call to GetNodes() +func (r *Recorder) GetNodes() (nodes []Node) { + return r.nodes +} diff --git a/lib/trie/record/recorder_test.go b/lib/trie/record/recorder_test.go new file mode 100644 index 00000000000..638661b97a7 --- /dev/null +++ b/lib/trie/record/recorder_test.go @@ -0,0 +1,115 @@ +package record + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewRecorder(t *testing.T) { + t.Parallel() + + expected := &Recorder{} + + recorder := NewRecorder() + + assert.Equal(t, expected, recorder) +} + +func Test_Recorder_Record(t *testing.T) { + testCases := map[string]struct { + recorder *Recorder + hash []byte + rawData []byte + expectedRecorder *Recorder + }{ + "nil data": { + recorder: &Recorder{}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {}, + }, + }, + }, + "insert in empty recorder": { + recorder: &Recorder{}, + hash: []byte{1, 2}, + rawData: []byte{3, 4}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + }, + "insert in non-empty recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + }, + }, + hash: []byte{1, 2}, + rawData: []byte{3, 4}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.recorder.Record(testCase.hash, testCase.rawData) + + assert.Equal(t, testCase.expectedRecorder, testCase.recorder) + }) + } +} + +func Test_Recorder_GetNodes(t *testing.T) { + testCases := map[string]struct { + recorder *Recorder + nodes []Node + }{ + "no node": { + recorder: &Recorder{}, + }, + "get single node from recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + nodes: []Node{{Hash: []byte{1, 2}, RawData: []byte{3, 4}}}, + }, + "get node from multiple nodes in recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, + }, + }, + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodes := testCase.recorder.GetNodes() + + assert.Equal(t, testCase.nodes, nodes) + }) + } +} diff --git a/lib/trie/recorder.go b/lib/trie/recorder.go deleted file mode 100644 index 6db2a841d0d..00000000000 --- a/lib/trie/recorder.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -// nodeRecord represets a record of a visited node -type nodeRecord struct { - rawData []byte - hash []byte -} - -// recorder keeps the list of nodes find by Lookup.Find -type recorder []nodeRecord - -// record insert a node inside the recorded list -func (r *recorder) record(h, rd []byte) { - *r = append(*r, nodeRecord{rawData: rd, hash: h}) -} - -// next returns the current item the cursor is on and increment the cursor by 1 -func (r *recorder) next() *nodeRecord { - if !r.isEmpty() { - n := (*r)[0] - *r = (*r)[1:] - return &n - } - - return nil -} - -// isEmpty returns bool if there is data inside the slice -func (r *recorder) isEmpty() bool { - return len(*r) <= 0 -}