From dbc59e508382367f9eaa93366444df82533859ab Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 10 Dec 2021 13:13:12 +0000 Subject: [PATCH] Refactor lib/trie/database.go - Explicit variable names - Less indentation - More return early - Add named returns - Clarify comments - Wrap errors --- dot/state/offline_pruner.go | 5 +- lib/trie/database.go | 408 ++++++++++++++++++++---------------- lib/trie/proof.go | 2 +- 3 files changed, 232 insertions(+), 183 deletions(-) diff --git a/dot/state/offline_pruner.go b/dot/state/offline_pruner.go index 3883b741fa0..9e42b976dff 100644 --- a/dot/state/offline_pruner.go +++ b/dot/state/offline_pruner.go @@ -117,10 +117,7 @@ func (p *OfflinePruner) SetBloomFilter() (err error) { return err } - err = tr.GetNodeHashes(tr.RootNode(), keys) - if err != nil { - return err - } + tr.GetNodeHashes(tr.RootNode(), keys) // get parent header of current block header, err = p.blockState.GetHeader(header.ParentHash) diff --git a/lib/trie/database.go b/lib/trie/database.go index 362720c5ce4..1752a46753c 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -33,23 +33,23 @@ func (t *Trie) Store(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) store(db chaindb.Batch, curr Node) error { - if curr == nil { +func (t *Trie) store(db chaindb.Batch, n Node) error { + if n == nil { return nil } - enc, hash, err := curr.EncodeAndHash() + encoding, hash, err := n.EncodeAndHash() if err != nil { return err } - err = db.Put(hash, enc) + err = db.Put(hash, encoding) if err != nil { return err } - if c, ok := curr.(*node.Branch); ok { - for _, child := range c.Children { + if branch, ok := n.(*node.Branch); ok { + for _, child := range branch.Children { if child == nil { continue } @@ -61,241 +61,270 @@ func (t *Trie) store(db chaindb.Batch, curr Node) error { } } - if curr.IsDirty() { - curr.SetDirty(false) + if n.IsDirty() { + n.SetDirty(false) } return nil } -// LoadFromProof create a partial trie based on the proof slice, 
as it only contains nodes that are in the proof afaik. -func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { - if len(proof) == 0 { +var ( + ErrDecodeNode = errors.New("cannot decode node") +) + +// loadFromProof create a partial trie based on the proof slice, as it only contains nodes that are in the proof afaik. +func (t *Trie) loadFromProof(rawProof [][]byte, rootHash []byte) error { + if len(rawProof) == 0 { return ErrEmptyProof } - mappedNodes := make(map[string]Node, len(proof)) + proofHashToNode := make(map[string]Node, len(rawProof)) - // map all the proofs hash -> decoded node - // and takes the loop to indentify the root node - for _, rawNode := range proof { - decNode, err := node.Decode(bytes.NewReader(rawNode)) + for i, rawNode := range rawProof { + decodedNode, err := node.Decode(bytes.NewReader(rawNode)) if err != nil { - return err + return fmt.Errorf("%w: at index %d: 0x%x", + ErrDecodeNode, i, rawNode) } - decNode.SetDirty(false) - decNode.SetEncodingAndHash(rawNode, nil) + const dirty = false + decodedNode.SetDirty(dirty) + decodedNode.SetEncodingAndHash(rawNode, nil) - _, computedRoot, err := decNode.EncodeAndHash() + _, hash, err := decodedNode.EncodeAndHash() if err != nil { - return err + return fmt.Errorf("cannot encode and hash node at index %d: %w", i, err) } - mappedNodes[common.BytesToHex(computedRoot)] = decNode + proofHash := common.BytesToHex(hash) + proofHashToNode[proofHash] = decodedNode - if bytes.Equal(computedRoot, root) { - t.root = decNode + if bytes.Equal(hash, rootHash) { + // Found root in proof + t.root = decodedNode } } - t.loadProof(mappedNodes, t.root) + t.loadProof(proofHashToNode, t.root) + return nil } // loadProof is a recursive function that will create all the trie paths based -// on the mapped proofs slice starting by the root -func (t *Trie) loadProof(proof map[string]Node, curr Node) { - c, ok := curr.(*node.Branch) +// on the mapped proofs slice starting at the root +func (t *Trie) 
loadProof(proofHashToNode map[string]Node, n Node) { + branch, ok := n.(*node.Branch) if !ok { return } - for i, child := range c.Children { + for i, child := range branch.Children { if child == nil { continue } - proofNode, ok := proof[common.BytesToHex(child.GetHash())] + proofHash := common.BytesToHex(child.GetHash()) + node, ok := proofHashToNode[proofHash] if !ok { continue } + delete(proofHashToNode, proofHash) - c.Children[i] = proofNode - t.loadProof(proof, proofNode) + branch.Children[i] = node + t.loadProof(proofHashToNode, node) } } // Load reconstructs the trie from the database from the given root hash. // It is used when restarting the node to load the current state trie. -func (t *Trie) Load(db chaindb.Database, root common.Hash) error { - if root == EmptyHash { +func (t *Trie) Load(db chaindb.Database, rootHash common.Hash) error { + if rootHash == EmptyHash { t.root = nil return nil } - enc, err := db.Get(root[:]) + rootHashBytes := rootHash[:] + + encodedNode, err := db.Get(rootHashBytes) if err != nil { - return fmt.Errorf("failed to find root key=%s: %w", root, err) + return fmt.Errorf("failed to find root key %s: %w", rootHash, err) } - t.root, err = node.Decode(bytes.NewReader(enc)) + reader := bytes.NewReader(encodedNode) + root, err := node.Decode(reader) if err != nil { - return err + return fmt.Errorf("cannot decode root node: %w", err) } - + t.root = root t.root.SetDirty(false) - t.root.SetEncodingAndHash(enc, root[:]) + t.root.SetEncodingAndHash(encodedNode, rootHashBytes) return t.load(db, t.root) } -func (t *Trie) load(db chaindb.Database, curr Node) error { - if c, ok := curr.(*node.Branch); ok { - for i, child := range c.Children { - if child == nil { - continue - } +func (t *Trie) load(db chaindb.Database, n Node) error { + branch, ok := n.(*node.Branch) + if !ok { + return nil + } - hash := child.GetHash() - enc, err := db.Get(hash) - if err != nil { - return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err) - } + 
for i, child := range branch.Children { + if child == nil { + continue + } - child, err = node.Decode(bytes.NewReader(enc)) - if err != nil { - return err - } + hash := child.GetHash() + encodedNode, err := db.Get(hash) + if err != nil { + return fmt.Errorf("cannot find child node key 0x%x in database: %w", hash, err) + } - child.SetDirty(false) - child.SetEncodingAndHash(enc, hash) + reader := bytes.NewReader(encodedNode) + decodedNode, err := node.Decode(reader) + if err != nil { + return fmt.Errorf("cannot decode node with hash 0x%x: %w", hash, err) + } - c.Children[i] = child - err = t.load(db, child) - if err != nil { - return err - } + decodedNode.SetDirty(false) + decodedNode.SetEncodingAndHash(encodedNode, hash) + branch.Children[i] = decodedNode + + err = t.load(db, decodedNode) + if err != nil { + return fmt.Errorf("cannot load child with hash 0x%x: %w", hash, err) } } return nil } -// GetNodeHashes return hash of each key of the trie. -func (t *Trie) GetNodeHashes(curr Node, keys map[common.Hash]struct{}) error { - if c, ok := curr.(*node.Branch); ok { - for _, child := range c.Children { - if child == nil { - continue - } - - hash := child.GetHash() - keys[common.BytesToHash(hash)] = struct{}{} +// GetNodeHashes writes hashes of each children of the node given +// as keys to the map hashesSet. +func (t *Trie) GetNodeHashes(n Node, hashesSet map[common.Hash]struct{}) { + branch, ok := n.(*node.Branch) + if !ok { + return + } - err := t.GetNodeHashes(child, keys) - if err != nil { - return err - } + for _, child := range branch.Children { + if child == nil { + continue } + + hash := common.BytesToHash(child.GetHash()) + hashesSet[hash] = struct{}{} + + t.GetNodeHashes(child, hashesSet) } - return nil } -// PutInDB puts a value into the trie and writes the updates nodes the database. -// Since it needs to write all the nodes from the changed node up to the root, -// it writes these in a batch operation. 
+// PutInDB inserts a value in the trie at the key given. +// It writes the updated nodes from the changed node up to the root node +// to the database in a batch operation. func (t *Trie) PutInDB(db chaindb.Database, key, value []byte) error { t.Put(key, value) return t.WriteDirty(db) } -// DeleteFromDB deletes a value from the trie and writes the updated nodes the database. -// Since it needs to write all the nodes from the changed node up to the root, -// it writes these in a batch operation. +// DeleteFromDB deletes a value from the trie at the key given. +// It writes the updated nodes from the changed node up to the root node +// to the database in a batch operation. func (t *Trie) DeleteFromDB(db chaindb.Database, key []byte) error { t.Delete(key) return t.WriteDirty(db) } -// ClearPrefixFromDB deletes all keys with the given prefix from the trie -// and writes the updated nodes the database. Since it needs to write all -// the nodes from the changed node up to the root, it writes these +// ClearPrefixFromDB deletes all nodes with keys starting with the given prefix +// from the trie. It writes the updated nodes from the changed node up to +// the root node to the database // in a batch operation. func (t *Trie) ClearPrefixFromDB(db chaindb.Database, prefix []byte) error { t.ClearPrefix(prefix) return t.WriteDirty(db) } -// GetFromDB retrieves a value from the trie using the database. +// GetFromDB retrieves a value at the given key from the trie using the database. // It recursively descends into the trie using the database starting // from the root node until it reaches the node with the given key. // It then reads the value from the database. 
-func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error) { - if root == EmptyHash { +func GetFromDB(db chaindb.Database, rootHash common.Hash, key []byte) ( + value []byte, err error) { + if rootHash == EmptyHash { return nil, nil } k := codec.KeyLEToNibbles(key) - enc, err := db.Get(root[:]) + encodedRootNode, err := db.Get(rootHash[:]) if err != nil { - return nil, fmt.Errorf("failed to find root key=%s: %w", root, err) + return nil, fmt.Errorf("cannot find root hash key 0x%x: %w", rootHash, err) } - rootNode, err := node.Decode(bytes.NewReader(enc)) + reader := bytes.NewReader(encodedRootNode) + rootNode, err := node.Decode(reader) if err != nil { - return nil, err + return nil, fmt.Errorf("cannot decode root node: %w", err) } return getFromDB(db, rootNode, k) } -func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { - var value []byte - - switch p := parent.(type) { - case *node.Branch: - length := lenCommonPrefix(p.Key, key) - - // found the value at this node - if bytes.Equal(p.Key, key) || len(key) == 0 { - return p.Value, nil - } - - // did not find value - if bytes.Equal(p.Key[:length], key) && len(key) < len(p.Key) { - return nil, nil - } - - if p.Children[key[length]] == nil { - return nil, nil +// getFromDB recursively searches through the trie and database +// for the value corresponding to a key. +// Note it does not copy the value so modifying the value bytes +// slice will modify the value of the node in the trie. 
+func getFromDB(db chaindb.Database, n Node, key []byte) ( + value []byte, err error) { + // if parent == nil { + // return nil, nil + // } + leaf, ok := n.(*node.Leaf) + if ok { + if bytes.Equal(leaf.Key, key) { + return leaf.Value, nil } + return nil, nil + } - // load child with potential value - enc, err := db.Get(p.Children[key[length]].(*node.Leaf).Hash) - if err != nil { - return nil, fmt.Errorf("failed to find node in database: %w", err) - } + branch := n.(*node.Branch) + // Key is equal to the key of this branch or is empty + if len(key) == 0 || bytes.Equal(branch.Key, key) { + return branch.Value, nil + } - child, err := node.Decode(bytes.NewReader(enc)) - if err != nil { - return nil, err - } + commonPrefixLength := lenCommonPrefix(branch.Key, key) + if len(key) < len(branch.Key) && bytes.Equal(branch.Key[:commonPrefixLength], key) { + // The key to search is a prefix of the node key and is smaller than the node key. + // Example: key to search: 0xabcd + // branch key: 0xabcdef + return nil, nil + } - value, err = getFromDB(db, child, key[length+1:]) - if err != nil { - return nil, err - } - case *node.Leaf: - if bytes.Equal(p.Key, key) { - return p.Value, nil - } - case nil: + // childIndex is the nibble after the common prefix length in the key being searched. + childIndex := key[commonPrefixLength] + childWithHashOnly := branch.Children[childIndex] + if childWithHashOnly == nil { return nil, nil + } + childHash := childWithHashOnly.GetHash() + encodedChild, err := db.Get(childHash) + if err != nil { + return nil, fmt.Errorf( + "cannot find child with hash 0x%x in database: %w", + childHash, err) } - return value, nil + + reader := bytes.NewReader(encodedChild) + decodedChild, err := node.Decode(reader) + if err != nil { + return nil, fmt.Errorf( + "cannot decode child node with hash 0x%x: %w", + childHash, err) + } + + return getFromDB(db, decodedChild, key[commonPrefixLength+1:]) + // Note: do not wrap error since it's called recursively. 
} // WriteDirty writes all dirty nodes to the database and sets them to clean @@ -310,94 +339,117 @@ func (t *Trie) WriteDirty(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error { - if curr == nil || !curr.IsDirty() { +func (t *Trie) writeDirty(db chaindb.Batch, n Node) error { + if n == nil || !n.IsDirty() { return nil } - enc, hash, err := curr.EncodeAndHash() + encoding, hash, err := n.EncodeAndHash() if err != nil { - return err + return fmt.Errorf( + "cannot encode and hash node with hash 0x%x: %w", + n.GetHash(), err) } - // always hash root even if encoding is under 32 bytes - if curr == t.root { - h, err := common.Blake2bHash(enc) + if n == t.root { + // hash root node even if its encoding is under 32 bytes + encodingDigest, err := common.Blake2bHash(encoding) if err != nil { - return err + return fmt.Errorf("cannot hash root node encoding: %w", err) } - hash = h[:] + hash = encodingDigest[:] } - err = db.Put(hash, enc) + err = db.Put(hash, encoding) if err != nil { - return err + return fmt.Errorf( + "cannot put encoding of node with hash 0x%x in database: %w", + hash, err) } - if c, ok := curr.(*node.Branch); ok { - for _, child := range c.Children { - if child == nil { - continue - } + branch, ok := n.(*node.Branch) + if !ok { + // the node is a leaf + n.SetDirty(false) + return nil + } - err = t.writeDirty(db, child) - if err != nil { - return err - } + for _, child := range branch.Children { + if child == nil { + continue + } + + err = t.writeDirty(db, child) + if err != nil { + // Note: do not wrap error since it's returned recursively. + return err } } - curr.SetDirty(false) + branch.SetDirty(false) + return nil } -// GetInsertedNodeHashes returns the hash of nodes that are inserted into state trie since last snapshot is called -// Since inserted nodes are newly created we need to compute their hash values. 
-func (t *Trie) GetInsertedNodeHashes() ([]common.Hash, error) { +// GetInsertedNodeHashes returns the hashes of all nodes that were +// inserted in the state trie since the last snapshot. +// We need to compute the hash values of each newly inserted node. +func (t *Trie) GetInsertedNodeHashes() (hashes []common.Hash, err error) { return t.getInsertedNodeHashes(t.root) } -func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) { - var nodeHashes []common.Hash - if curr == nil || !curr.IsDirty() { +func (t *Trie) getInsertedNodeHashes(n Node) (hashes []common.Hash, err error) { + // TODO pass map of hashes or slice as argument to avoid copying + // and using more memory. + if n == nil || !n.IsDirty() { return nil, nil } - enc, hash, err := curr.EncodeAndHash() + encoding, hash, err := n.EncodeAndHash() if err != nil { - return nil, err + return nil, fmt.Errorf( + "cannot encode and hash node with hash 0x%x: %w", + n.GetHash(), err) } - if curr == t.root && len(enc) < 32 { - h, err := common.Blake2bHash(enc) + if n == t.root && len(encoding) < 32 { + // hash root node even if its encoding is under 32 bytes + encodingDigest, err := common.Blake2bHash(encoding) if err != nil { - return nil, err + return nil, fmt.Errorf("cannot hash root node encoding: %w", err) } - hash = h[:] + hash = encodingDigest[:] } - nodeHash := common.BytesToHash(hash) - nodeHashes = append(nodeHashes, nodeHash) + hashes = append(hashes, common.BytesToHash(hash)) - if c, ok := curr.(*node.Branch); ok { - for _, child := range c.Children { - if child == nil { - continue - } - nodes, err := t.getInsertedNodeHashes(child) - if err != nil { - return nil, err - } - nodeHashes = append(nodeHashes, nodes...) 
+ branch, ok := n.(*node.Branch) + if !ok { + // node is a leaf + return hashes, nil + } + + for _, child := range branch.Children { + if child == nil { + continue + } + + deeperHashes, err := t.getInsertedNodeHashes(child) + if err != nil { + // Note: do not wrap error since this is called recursively. + return nil, err } + + hashes = append(hashes, deeperHashes...) } - return nodeHashes, nil + return hashes, nil } -// GetDeletedNodeHash returns the hash of nodes that are deleted from state trie since last snapshot is called +// GetDeletedNodeHash returns the hash of nodes that were +// deleted from the trie since the last snapshot was made. func (t *Trie) GetDeletedNodeHash() []common.Hash { return t.deletedKeys } diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 2d8444d2db4..4bd2b0c066c 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -84,7 +84,7 @@ func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { } proofTrie := NewEmptyTrie() - if err := proofTrie.LoadFromProof(proof, root); err != nil { + if err := proofTrie.loadFromProof(proof, root); err != nil { return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) }