Skip to content

Commit

Permalink
Adapt trie code to use Type() before type assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Dec 14, 2021
1 parent 5655da3 commit ed0a1d9
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 73 deletions.
4 changes: 3 additions & 1 deletion internal/trie/node/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ func Decode(reader io.Reader) (n Node, err error) {
// find other values using the persistent database.
func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) {
nodeType := Type(header >> 6)
if nodeType != BranchType && nodeType != BranchWithValueType {
switch nodeType {
case BranchType, BranchWithValueType:
default:
return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType)
}

Expand Down
6 changes: 3 additions & 3 deletions internal/trie/node/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
func (b *Branch) encodeHeader(writer io.Writer) (err error) {
var header byte
if b.Value == nil {
header = 2 << 6
header = byte(BranchType) << 6
} else {
header = 3 << 6
header = byte(BranchWithValueType) << 6
}

if len(b.Key) >= 63 {
Expand All @@ -40,7 +40,7 @@ func (b *Branch) encodeHeader(writer io.Writer) (err error) {

// encodeHeader creates the encoded header for the leaf.
func (l *Leaf) encodeHeader(writer io.Writer) (err error) {
var header byte = 1 << 6
header := byte(LeafType) << 6

if len(l.Key) < 63 {
header = header | byte(len(l.Key))
Expand Down
146 changes: 87 additions & 59 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ func (t *Trie) store(db chaindb.Batch, curr Node) error {
return err
}

if c, ok := curr.(*node.Branch); ok {
for _, child := range c.Children {
switch curr.Type() {
case node.BranchType, node.BranchWithValueType:
branch := curr.(*node.Branch)
for _, child := range branch.Children {
if child == nil {
continue
}
Expand Down Expand Up @@ -105,12 +107,15 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error {

// loadProof is a recursive function that will create all the trie paths based
// on the mapped proofs slice starting by the root
func (t *Trie) loadProof(proof map[string]Node, curr Node) {
c, ok := curr.(*node.Branch)
if !ok {
func (t *Trie) loadProof(proof map[string]Node, curr node.Node) {
switch curr.Type() {
case node.BranchType, node.BranchWithValueType:
default:
return
}

c := curr.(*node.Branch)

for i, child := range c.Children {
if child == nil {
continue
Expand Down Expand Up @@ -150,55 +155,67 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error {
return t.load(db, t.root)
}

func (t *Trie) load(db chaindb.Database, curr Node) error {
if c, ok := curr.(*node.Branch); ok {
for i, child := range c.Children {
if child == nil {
continue
}
func (t *Trie) load(db chaindb.Database, curr node.Node) error {
switch curr.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return nil
}

hash := child.GetHash()
enc, err := db.Get(hash)
if err != nil {
return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err)
}
c := curr.(*node.Branch)

child, err = node.Decode(bytes.NewReader(enc))
if err != nil {
return err
}
for i, child := range c.Children {
if child == nil {
continue
}

child.SetDirty(false)
child.SetEncodingAndHash(enc, hash)
hash := child.GetHash()
enc, err := db.Get(hash)
if err != nil {
return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err)
}

c.Children[i] = child
err = t.load(db, child)
if err != nil {
return err
}
child, err = node.Decode(bytes.NewBuffer(enc))
if err != nil {
return err
}

child.SetDirty(false)
child.SetEncodingAndHash(enc, hash)

c.Children[i] = child
err = t.load(db, child)
if err != nil {
return err
}
}

return nil
}

// GetNodeHashes return hash of each key of the trie.
func (t *Trie) GetNodeHashes(curr Node, keys map[common.Hash]struct{}) error {
if c, ok := curr.(*node.Branch); ok {
for _, child := range c.Children {
if child == nil {
continue
}
func (t *Trie) GetNodeHashes(curr node.Node, keys map[common.Hash]struct{}) error {
switch curr.Type() {
case node.BranchType, node.BranchWithValueType:
default:
return nil
}

hash := child.GetHash()
keys[common.BytesToHash(hash)] = struct{}{}
c := curr.(*node.Branch)
for _, child := range c.Children {
if child == nil {
continue
}

err := t.GetNodeHashes(child, keys)
if err != nil {
return err
}
hash := child.GetHash()
keys[common.BytesToHash(hash)] = struct{}{}

err := t.GetNodeHashes(child, keys)
if err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -335,20 +352,26 @@ func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error {
return err
}

if c, ok := curr.(*node.Branch); ok {
for _, child := range c.Children {
if child == nil {
continue
}
curr.SetDirty(false)

err = t.writeDirty(db, child)
if err != nil {
return err
}
switch curr.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return nil
}

c := curr.(*node.Branch)
for _, child := range c.Children {
if child == nil {
continue
}

err = t.writeDirty(db, child)
if err != nil {
return err
}
}

curr.SetDirty(false)
return nil
}

Expand Down Expand Up @@ -381,17 +404,22 @@ func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) {
nodeHash := common.BytesToHash(hash)
nodeHashes = append(nodeHashes, nodeHash)

if c, ok := curr.(*node.Branch); ok {
for _, child := range c.Children {
if child == nil {
continue
}
nodes, err := t.getInsertedNodeHashes(child)
if err != nil {
return nil, err
}
nodeHashes = append(nodeHashes, nodes...)
switch curr.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return nodeHashes, nil
}

c := curr.(*node.Branch)
for _, child := range c.Children {
if child == nil {
continue
}
nodes, err := t.getInsertedNodeHashes(child)
if err != nil {
return nil, err
}
nodeHashes = append(nodeHashes, nodes...)
}

return nodeHashes, nil
Expand Down
7 changes: 5 additions & 2 deletions lib/trie/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ func find(parent Node, key []byte, recorder recorder) error {

recorder.Record(hash, enc)

b, ok := parent.(*node.Branch)
if !ok {
switch parent.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return nil
}

b := parent.(*node.Branch)

length := lenCommonPrefix(b.Key, key)

// found the value at this node
Expand Down
21 changes: 13 additions & 8 deletions lib/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,23 @@ func (t *Trie) tryPut(key, value []byte) {

// insert attempts to insert a key with value into the trie
func (t *Trie) insert(parent Node, key []byte, value Node) Node {
switch p := t.maybeUpdateGeneration(parent).(type) {
case *node.Branch:
newParent := t.maybeUpdateGeneration(parent)
if newParent == nil {
value.SetKey(key)
return value
}

switch newParent.Type() {
case node.BranchType, node.BranchWithValueType:
p := newParent.(*node.Branch)
n := t.updateBranch(p, key, value)

if p != nil && n != nil && n.IsDirty() {
p.SetDirty(true)
}
return n
case nil:
value.SetKey(key)
return value
case *node.Leaf:
case node.LeafType:
p := newParent.(*node.Leaf)
// if a value already exists in the trie at this key, overwrite it with the new value
// if the values are the same, don't mark node dirty
if p.Value != nil && bytes.Equal(p.Key, key) {
Expand Down Expand Up @@ -324,9 +329,9 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node {
}

return br
default:
panic("unknown node type: " + fmt.Sprint(newParent.Type()))
}
// This will never happen.
return nil
}

// updateBranch attempts to add the value node to a branch
Expand Down

0 comments on commit ed0a1d9

Please sign in to comment.