diff --git a/lib/trie/proof.go b/lib/trie/proof.go index eccae4ce11..094fe48f7e 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -47,13 +47,7 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e return nil, err } - for { - recNode, err := recorder.Next() - if errors.Is(err, record.ErrNoNextNode) { - break - } else if err != nil { - return nil, fmt.Errorf("recorder failed for key 0x%x: %w", k, err) - } + for _, recNode := range recorder.GetNodes() { nodeHashHex := common.BytesToHex(recNode.Hash) if _, ok := trackedProofs[nodeHashHex]; !ok { trackedProofs[nodeHashHex] = recNode.RawData diff --git a/lib/trie/record/recorder.go b/lib/trie/record/recorder.go index 4281bd4dd2..130b434338 100644 --- a/lib/trie/record/recorder.go +++ b/lib/trie/record/recorder.go @@ -3,12 +3,6 @@ package record -import "errors" - -var ( - ErrNoNextNode = errors.New("no next node") -) - // Recorder records the list of nodes found by Lookup.Find type Recorder struct { nodes []Node @@ -24,15 +18,10 @@ func (r *Recorder) Record(hash, rawData []byte) { r.nodes = append(r.nodes, Node{RawData: rawData, Hash: hash}) } -// Next returns the first node in the recorded list -// and removes it (shift operation). -func (r *Recorder) Next() (node Node, err error) { - if len(r.nodes) == 0 { - return node, ErrNoNextNode - } - - node = r.nodes[0] - r.nodes = r.nodes[1:] - - return node, nil +// GetNodes returns all the nodes recorded. +// Note it does not copy its slice of nodes. +// It's fine to not copy them since the recorder +// is not used again after a call to GetNodes() +func (r *Recorder) GetNodes() (nodes []Node) { + return r.nodes } diff --git a/lib/trie/record/recorder_test.go b/lib/trie/record/recorder_test.go index 624889c545..638661b97a 100644 --- a/lib/trie/record/recorder_test.go +++ b/lib/trie/record/recorder_test.go @@ -70,17 +70,13 @@ func Test_Recorder_Record(t *testing.T) { } } -func Test_Recorder_Next(t *testing.T) { +func Test_Recorder_GetNodes(t *testing.T) { testCases := map[string]struct { - recorder *Recorder - node Node - err error - expectedRecorder *Recorder + recorder *Recorder + nodes []Node }{ "no node": { - recorder: &Recorder{}, - err: ErrNoNextNode, - expectedRecorder: &Recorder{}, + recorder: &Recorder{}, }, "get single node from recorder": { recorder: &Recorder{ @@ -88,10 +84,7 @@ func Test_Recorder_Next(t *testing.T) { {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, }, }, - node: Node{Hash: []byte{1, 2}, RawData: []byte{3, 4}}, - expectedRecorder: &Recorder{ - nodes: []Node{}, - }, + nodes: []Node{{Hash: []byte{1, 2}, RawData: []byte{3, 4}}}, }, "get node from multiple nodes in recorder": { recorder: &Recorder{ @@ -101,12 +94,10 @@ func Test_Recorder_Next(t *testing.T) { {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, }, }, - node: Node{Hash: []byte{1, 2}, RawData: []byte{3, 4}}, - expectedRecorder: &Recorder{ - nodes: []Node{ - {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, - {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, - }, + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, }, }, } @@ -116,14 +107,9 @@ func Test_Recorder_Next(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - node, err := testCase.recorder.Next() + nodes := testCase.recorder.GetNodes() - assert.ErrorIs(t, err, testCase.err) - if testCase.err != nil { - assert.EqualError(t, err, testCase.err.Error()) - } - assert.Equal(t, testCase.node, node) - assert.Equal(t, testCase.expectedRecorder, testCase.recorder) + assert.Equal(t, testCase.nodes, nodes) }) } }