diff --git a/lib/trie/trie.go b/lib/trie/trie.go index e98d6b27dd..0b81cca93c 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -62,17 +62,37 @@ func (t *Trie) Snapshot() (newTrie *Trie) { } } +func (t *Trie) prepLeafForMutation(currentLeaf *node.Leaf) (newLeaf *node.Leaf) { + if currentLeaf.Generation == t.generation { + // no need to deep copy and update generation + // of current leaf. + newLeaf = currentLeaf + } else { + newNode := updateGeneration(currentLeaf, t.generation, t.deletedKeys) + newLeaf = newNode.(*node.Leaf) + } + newLeaf.SetDirty(true) + return newLeaf +} + +func (t *Trie) prepBranchForMutation(currentBranch *node.Branch) (newBranch *node.Branch) { + if currentBranch.Generation == t.generation { + // no need to deep copy and update generation + // of current branch. + newBranch = currentBranch + } else { + newNode := updateGeneration(currentBranch, t.generation, t.deletedKeys) + newBranch = newNode.(*node.Branch) + } + newBranch.SetDirty(true) + return newBranch +} + // updateGeneration is called when the currentNode is from // an older trie generation (snapshot) so we deep copy the // node and update the generation on the newer copy. func updateGeneration(currentNode Node, trieGeneration uint64, deletedHashes map[common.Hash]struct{}) (newNode Node) { - if currentNode.GetGeneration() == trieGeneration { - panic(fmt.Sprintf( - "current node has the same generation %d as the trie generation, "+ - "make sure the caller properly checks for the node generation to "+ - "be smaller than the trie generation.", trieGeneration)) - } const copyChildren = false newNode = currentNode.Copy(copyChildren) newNode.SetGeneration(trieGeneration) @@ -322,17 +342,13 @@ func (t *Trie) insert(parent Node, key, value []byte) (newParent Node) { } // TODO ensure all values have dirty set to true - newParent = parent - if parent.GetGeneration() < t.generation { - newParent = updateGeneration(parent, t.generation, t.deletedKeys) - } - switch newParent.Type() { + switch parent.Type() { case node.BranchType, node.BranchWithValueType: - parentBranch := newParent.(*node.Branch) + parentBranch := parent.(*node.Branch) return t.insertInBranch(parentBranch, key, value) default: - parentLeaf := newParent.(*node.Leaf) + parentLeaf := parent.(*node.Leaf) return t.insertInLeaf(parentLeaf, key, value) } } @@ -340,11 +356,12 @@ func (t *Trie) insert(parent Node, key, value []byte) (newParent Node) { func (t *Trie) insertInLeaf(parentLeaf *node.Leaf, key, value []byte) (newParent Node) { if bytes.Equal(parentLeaf.Key, key) { - if !bytes.Equal(value, parentLeaf.Value) { - parentLeaf.Value = value - parentLeaf.Generation = t.generation - parentLeaf.SetDirty(true) + if bytes.Equal(value, parentLeaf.Value) { + return parentLeaf } + + parentLeaf = t.prepLeafForMutation(parentLeaf) + parentLeaf.Value = value return parentLeaf } @@ -364,9 +381,9 @@ func (t *Trie) insertInLeaf(parentLeaf *node.Leaf, key, if len(key) < len(parentLeafKey) { // Move the current leaf parent as a child to the new branch. + parentLeaf = t.prepLeafForMutation(parentLeaf) childIndex := parentLeafKey[commonPrefixLength] parentLeaf.Key = parentLeaf.Key[commonPrefixLength+1:] - parentLeaf.SetDirty(true) newBranchParent.Children[childIndex] = parentLeaf } @@ -378,9 +395,9 @@ func (t *Trie) insertInLeaf(parentLeaf *node.Leaf, key, newBranchParent.Value = parentLeaf.Value } else { // make the leaf a child of the new branch + parentLeaf = t.prepLeafForMutation(parentLeaf) childIndex := parentLeafKey[commonPrefixLength] parentLeaf.Key = parentLeaf.Key[commonPrefixLength+1:] - parentLeaf.SetDirty(true) newBranchParent.Children[childIndex] = parentLeaf } childIndex := key[commonPrefixLength] @@ -395,9 +412,9 @@ func (t *Trie) insertInLeaf(parentLeaf *node.Leaf, key, } func (t *Trie) insertInBranch(parentBranch *node.Branch, key, value []byte) (newParent Node) { + parentBranch = t.prepBranchForMutation(parentBranch) + if bytes.Equal(key, parentBranch.Key) { - parentBranch.SetDirty(true) - parentBranch.Generation = t.generation parentBranch.Value = value return parentBranch } @@ -418,12 +435,9 @@ func (t *Trie) insertInBranch(parentBranch *node.Branch, key, value []byte) (new } } else { child = t.insert(child, remainingKey, value) - child.SetDirty(true) } parentBranch.Children[childIndex] = child - parentBranch.SetDirty(true) - parentBranch.Generation = t.generation return parentBranch } @@ -435,13 +449,11 @@ func (t *Trie) insertInBranch(parentBranch *node.Branch, key, value []byte) (new Generation: t.generation, Dirty: true, } - parentBranch.SetDirty(true) oldParentIndex := parentBranch.Key[commonPrefixLength] remainingOldParentKey := parentBranch.Key[commonPrefixLength+1:] parentBranch.Key = remainingOldParentKey - parentBranch.Generation = t.generation newParentBranch.Children[oldParentIndex] = parentBranch if len(key) <= commonPrefixLength { @@ -653,13 +665,8 @@ func (t *Trie) clearPrefixLimit(parent Node, prefix []byte, limit uint32) ( return nil, 0, true } - newParent = parent - if parent.GetGeneration() < t.generation { - newParent = updateGeneration(parent, t.generation, t.deletedKeys) - } - - if newParent.Type() == node.LeafType { - leaf := newParent.(*node.Leaf) + if parent.Type() == node.LeafType { + leaf := parent.(*node.Leaf) // if prefix is not found, it's also all deleted. // TODO check this is the same behaviour as in substrate const allDeleted = true @@ -667,22 +674,11 @@ func (t *Trie) clearPrefixLimit(parent Node, prefix []byte, limit uint32) ( valuesDeleted = 1 return nil, valuesDeleted, allDeleted } - // not modified so return the leaf of the original - // trie generation. The copied leaf newParent will be - // garbage collected. return parent, 0, allDeleted } - branch := newParent.(*node.Branch) - newParent, valuesDeleted, allDeleted = t.clearPrefixLimitBranch(branch, prefix, limit) - if valuesDeleted == 0 { - // not modified so return the node of the original - // trie generation. The copied newParent will be - // garbage collected. - newParent = parent - } - - return newParent, valuesDeleted, allDeleted + branch := parent.(*node.Branch) + return t.clearPrefixLimitBranch(branch, prefix, limit) } func (t *Trie) clearPrefixLimitBranch(branch *node.Branch, prefix []byte, limit uint32) ( @@ -714,13 +710,14 @@ func (t *Trie) clearPrefixLimitBranch(branch *node.Branch, prefix []byte, limit childPrefix := prefix[len(branch.Key)+1:] child := branch.Children[childIndex] - newParent = branch // mostly just a reminder for the reader - branch.Children[childIndex], valuesDeleted, allDeleted = t.clearPrefixLimit(child, childPrefix, limit) - if valuesDeleted > 0 { - branch.SetDirty(true) - newParent = handleDeletion(branch, prefix) + child, valuesDeleted, allDeleted = t.clearPrefixLimit(child, childPrefix, limit) + if valuesDeleted == 0 { + return branch, valuesDeleted, allDeleted } + branch = t.prepBranchForMutation(branch) + branch.Children[childIndex] = child + newParent = handleDeletion(branch, prefix) return newParent, valuesDeleted, allDeleted } @@ -738,11 +735,16 @@ func (t *Trie) clearPrefixLimitChild(branch *node.Branch, prefix []byte, limit u } nilPrefix := ([]byte)(nil) - branch.Children[childIndex], valuesDeleted = t.deleteNodesLimit(child, nilPrefix, limit) - branch.SetDirty(true) + child, valuesDeleted = t.deleteNodesLimit(child, nilPrefix, limit) + if valuesDeleted == 0 { + allDeleted = branch.Children[childIndex] == nil + return branch, valuesDeleted, allDeleted + } - newParent = handleDeletion(branch, prefix) + branch = t.prepBranchForMutation(branch) + branch.Children[childIndex] = child + newParent = handleDeletion(branch, prefix) allDeleted = branch.Children[childIndex] == nil return newParent, valuesDeleted, allDeleted } @@ -757,17 +759,12 @@ func (t *Trie) deleteNodesLimit(parent Node, prefix []byte, limit uint32) ( return nil, 0 } - newParent = parent - if parent.GetGeneration() < t.generation { - newParent = updateGeneration(parent, t.generation, t.deletedKeys) - } - - if newParent.Type() == node.LeafType { + if parent.Type() == node.LeafType { valuesDeleted = 1 return nil, valuesDeleted } - branch := newParent.(*node.Branch) + branch := parent.(*node.Branch) fullKey := concatenateSlices(prefix, branch.Key) @@ -779,6 +776,7 @@ func (t *Trie) deleteNodesLimit(parent Node, prefix []byte, limit uint32) ( continue } + branch = t.prepBranchForMutation(branch) branch.Children[i], newDeleted = t.deleteNodesLimit(child, fullKey, limit) if branch.Children[i] == nil { nilChildren++ @@ -786,7 +784,6 @@ func (t *Trie) deleteNodesLimit(parent Node, prefix []byte, limit uint32) ( limit -= newDeleted valuesDeleted += newDeleted - branch.SetDirty(true) newParent = handleDeletion(branch, fullKey) if nilChildren == node.ChildrenCapacity && branch.Value == nil { @@ -825,23 +822,15 @@ func (t *Trie) clearPrefix(parent Node, prefix []byte) ( return nil, false } - newParent = parent - if parent.GetGeneration() < t.generation { - newParent = updateGeneration(parent, t.generation, t.deletedKeys) - } - - if bytes.HasPrefix(newParent.GetKey(), prefix) { + if bytes.HasPrefix(parent.GetKey(), prefix) { return nil, true } - if newParent.Type() == node.LeafType { - // not modified so return the leaf of the original - // trie generation. The copied newParent will be - // garbage collected. + if parent.Type() == node.LeafType { return parent, false } - branch := newParent.(*node.Branch) + branch := parent.(*node.Branch) if len(prefix) == len(branch.Key)+1 && bytes.HasPrefix(branch.Key, prefix[:len(prefix)-1]) { @@ -850,15 +839,11 @@ func (t *Trie) clearPrefix(parent Node, prefix []byte) ( child := branch.Children[childIndex] if child == nil { - // child is already nil at the child index - // node is not modified so return the branch of the original - // trie generation. The copied newParent will be - // garbage collected. return parent, false } + branch = t.prepBranchForMutation(branch) branch.Children[childIndex] = nil - branch.SetDirty(true) newParent = handleDeletion(branch, prefix) return newParent, true } @@ -866,9 +851,6 @@ func (t *Trie) clearPrefix(parent Node, prefix []byte) ( noPrefixForNode := len(prefix) <= len(branch.Key) || lenCommonPrefix(branch.Key, prefix) < len(branch.Key) if noPrefixForNode { - // not modified so return the branch of the original - // trie generation. The copied newParent will be - // garbage collected. return parent, false } @@ -876,15 +858,13 @@ func (t *Trie) clearPrefix(parent Node, prefix []byte) ( childPrefix := prefix[len(branch.Key)+1:] child := branch.Children[childIndex] - branch.Children[childIndex], updated = t.clearPrefix(child, childPrefix) + child, updated = t.clearPrefix(child, childPrefix) if !updated { - // branch not modified so return the branch of the original - // trie generation. The copied newParent will be - // garbage collected. return parent, false } - branch.SetDirty(true) + branch = t.prepBranchForMutation(branch) + branch.Children[childIndex] = child newParent = handleDeletion(branch, prefix) return newParent, true } @@ -902,32 +882,15 @@ func (t *Trie) delete(parent Node, key []byte) (newParent Node, deleted bool) { return nil, false } - newParent = parent - if parent.GetGeneration() < t.generation { - newParent = updateGeneration(parent, t.generation, t.deletedKeys) - } - - if newParent.Type() == node.LeafType { - newParent = deleteLeaf(newParent, key) - if newParent == nil { + if parent.Type() == node.LeafType { + if deleteLeaf(parent, key) == nil { return nil, true } - // The leaf was not deleted so return the original - // parent without its generation updated. - // The copied newParent will be garbage collected. return parent, false } - branch := newParent.(*node.Branch) - newParent, deleted = t.deleteBranch(branch, key) - if !deleted { - // Nothing was deleted so return the original - // parent without its generation updated. - // The copied newParent will be garbage collected. - return parent, false - } - - return newParent, true + branch := parent.(*node.Branch) + return t.deleteBranch(branch, key) } func deleteLeaf(parent Node, key []byte) (newParent Node) { @@ -939,8 +902,8 @@ func deleteLeaf(parent Node, key []byte) (newParent Node) { func (t *Trie) deleteBranch(branch *node.Branch, key []byte) (newParent Node, deleted bool) { if len(key) == 0 || bytes.Equal(branch.Key, key) { + branch = t.prepBranchForMutation(branch) branch.Value = nil - branch.SetDirty(true) return handleDeletion(branch, key), true } @@ -954,8 +917,8 @@ func (t *Trie) deleteBranch(branch *node.Branch, key []byte) (newParent Node, de return branch, false } + branch = t.prepBranchForMutation(branch) branch.Children[childIndex] = newChild - branch.SetDirty(true) newParent = handleDeletion(branch, key) return newParent, true } diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index eaeb5b1ace..7ee8858c51 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -159,19 +159,6 @@ func Test_Trie_updateGeneration(t *testing.T) { } }) } - - t.Run("panic on same generation", func(t *testing.T) { - t.Parallel() - node := &node.Leaf{Generation: 1} - const trieGenration = 1 - assert.PanicsWithValue(t, - "current node has the same generation 1 as the trie generation, "+ - "make sure the caller properly checks for the node generation to "+ - "be smaller than the trie generation.", - func() { - updateGeneration(node, trieGenration, nil) - }) - }) } func getPointer(x interface{}) (pointer uintptr, ok bool) { @@ -1204,9 +1191,8 @@ func Test_Trie_insert(t *testing.T) { key: []byte{1}, value: []byte("same"), newNode: &node.Leaf{ - Key: []byte{1}, - Value: []byte("same"), - Generation: 1, + Key: []byte{1}, + Value: []byte("same"), }, }, "write leaf as child to parent leaf": { @@ -2518,6 +2504,26 @@ func Test_Trie_clearPrefixLimit(t *testing.T) { valuesDeleted: 2, allDeleted: true, }, + "delete child of branch with limit reached": { + trie: Trie{ + generation: 1, + }, + parent: &node.Branch{ + Key: []byte{1}, + Value: []byte{1}, + Children: [16]node.Node{ + &node.Leaf{Key: []byte{3}}, + }, + }, + prefix: []byte{1, 0}, + newParent: &node.Branch{ + Key: []byte{1}, + Value: []byte{1}, + Children: [16]node.Node{ + &node.Leaf{Key: []byte{3}}, + }, + }, + }, } for name, testCase := range testCases {