Skip to content

Commit

Permalink
Recorder GetNodes()
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Nov 30, 2021
1 parent 4078ae1 commit ba6cc40
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 49 deletions.
8 changes: 1 addition & 7 deletions lib/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,7 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e
return nil, err
}

for {
recNode, err := recorder.Next()
if errors.Is(err, record.ErrNoNextNode) {
break
} else if err != nil {
return nil, fmt.Errorf("recorder failed for key 0x%x: %w", k, err)
}
for _, recNode := range recorder.GetNodes() {
nodeHashHex := common.BytesToHex(recNode.Hash)
if _, ok := trackedProofs[nodeHashHex]; !ok {
trackedProofs[nodeHashHex] = recNode.RawData
Expand Down
23 changes: 6 additions & 17 deletions lib/trie/record/recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@

package record

import "errors"

var (
ErrNoNextNode = errors.New("no next node")
)

// Recorder records the list of nodes found by Lookup.Find
type Recorder struct {
nodes []Node
Expand All @@ -24,15 +18,10 @@ func (r *Recorder) Record(hash, rawData []byte) {
r.nodes = append(r.nodes, Node{RawData: rawData, Hash: hash})
}

// Next returns the first node in the recorded list
// and removes it (shift operation).
func (r *Recorder) Next() (node Node, err error) {
if len(r.nodes) == 0 {
return node, ErrNoNextNode
}

node = r.nodes[0]
r.nodes = r.nodes[1:]

return node, nil
// GetNodes returns all the nodes recorded.
// Note it does not copy its slice of nodes.
// It's fine to not copy them since the recorder
// is not used again after a call to GetNodes()
func (r *Recorder) GetNodes() (nodes []Node) {
return r.nodes
}
36 changes: 11 additions & 25 deletions lib/trie/record/recorder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,28 +70,21 @@ func Test_Recorder_Record(t *testing.T) {
}
}

func Test_Recorder_Next(t *testing.T) {
func Test_Recorder_GetNodes(t *testing.T) {
testCases := map[string]struct {
recorder *Recorder
node Node
err error
expectedRecorder *Recorder
recorder *Recorder
nodes []Node
}{
"no node": {
recorder: &Recorder{},
err: ErrNoNextNode,
expectedRecorder: &Recorder{},
recorder: &Recorder{},
},
"get single node from recorder": {
recorder: &Recorder{
nodes: []Node{
{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
},
},
node: Node{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
expectedRecorder: &Recorder{
nodes: []Node{},
},
nodes: []Node{{Hash: []byte{1, 2}, RawData: []byte{3, 4}}},
},
"get node from multiple nodes in recorder": {
recorder: &Recorder{
Expand All @@ -101,12 +94,10 @@ func Test_Recorder_Next(t *testing.T) {
{Hash: []byte{9, 6}, RawData: []byte{7, 8}},
},
},
node: Node{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
expectedRecorder: &Recorder{
nodes: []Node{
{Hash: []byte{5, 6}, RawData: []byte{7, 8}},
{Hash: []byte{9, 6}, RawData: []byte{7, 8}},
},
nodes: []Node{
{Hash: []byte{1, 2}, RawData: []byte{3, 4}},
{Hash: []byte{5, 6}, RawData: []byte{7, 8}},
{Hash: []byte{9, 6}, RawData: []byte{7, 8}},
},
},
}
Expand All @@ -116,14 +107,9 @@ func Test_Recorder_Next(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()

node, err := testCase.recorder.Next()
nodes := testCase.recorder.GetNodes()

assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
assert.EqualError(t, err, testCase.err.Error())
}
assert.Equal(t, testCase.node, node)
assert.Equal(t, testCase.expectedRecorder, testCase.recorder)
assert.Equal(t, testCase.nodes, nodes)
})
}
}

0 comments on commit ba6cc40

Please sign in to comment.