Skip to content

Commit

Permalink
Adapt trie code to use Type() before type assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jan 5, 2022
1 parent 7d0e14d commit f986135
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 30 deletions.
4 changes: 3 additions & 1 deletion internal/trie/node/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ func Decode(reader io.Reader) (n Node, err error) {
// find other values using the persistent database.
func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) {
nodeType := Type(header >> 6)
if nodeType != BranchType && nodeType != BranchWithValueType {
switch nodeType {
case BranchType, BranchWithValueType:
default:
return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType)
}

Expand Down
6 changes: 3 additions & 3 deletions internal/trie/node/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ const (
func (b *Branch) encodeHeader(writer io.Writer) (err error) {
var header byte
if b.Value == nil {
header = 2 << 6
header = byte(BranchType) << 6
} else {
header = 3 << 6
header = byte(BranchWithValueType) << 6
}

if len(b.Key) >= keyLenOffset {
Expand All @@ -44,7 +44,7 @@ func (b *Branch) encodeHeader(writer io.Writer) (err error) {

// encodeHeader creates the encoded header for the leaf.
func (l *Leaf) encodeHeader(writer io.Writer) (err error) {
var header byte = 1 << 6
header := byte(LeafType) << 6

if len(l.Key) < 63 {
header |= byte(len(l.Key))
Expand Down
46 changes: 30 additions & 16 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ func (t *Trie) store(db chaindb.Batch, n Node) error {
return err
}

if branch, ok := n.(*node.Branch); ok {
switch n.Type() {
case node.BranchType, node.BranchWithValueType:
branch := n.(*node.Branch)
for _, child := range branch.Children {
if child == nil {
continue
Expand Down Expand Up @@ -111,11 +113,14 @@ func (t *Trie) loadFromProof(rawProof [][]byte, rootHash []byte) error {
// loadProof is a recursive function that will create all the trie paths based
// on the mapped proofs slice starting at the root
func (t *Trie) loadProof(proofHashToNode map[string]Node, n Node) {
branch, ok := n.(*node.Branch)
if !ok {
switch n.Type() {
case node.BranchType, node.BranchWithValueType:
default:
return
}

branch := n.(*node.Branch)

for i, child := range branch.Children {
if child == nil {
continue
Expand Down Expand Up @@ -161,11 +166,14 @@ func (t *Trie) Load(db chaindb.Database, rootHash common.Hash) error {
}

func (t *Trie) load(db chaindb.Database, n Node) error {
branch, ok := n.(*node.Branch)
if !ok {
switch n.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return nil
}

branch := n.(*node.Branch)

for i, child := range branch.Children {
if child == nil {
continue
Expand Down Expand Up @@ -199,11 +207,14 @@ func (t *Trie) load(db chaindb.Database, n Node) error {
// PopulateNodeHashes writes hashes of each children of the node given
// as keys to the map hashesSet.
func (t *Trie) PopulateNodeHashes(n Node, hashesSet map[common.Hash]struct{}) {
branch, ok := n.(*node.Branch)
if !ok {
switch n.Type() {
case node.BranchType, node.BranchWithValueType:
default:
return
}

branch := n.(*node.Branch)

for _, child := range branch.Children {
if child == nil {
continue
Expand Down Expand Up @@ -363,13 +374,16 @@ func (t *Trie) writeDirty(db chaindb.Batch, n Node) error {
hash, err)
}

branch, ok := n.(*node.Branch)
if !ok {
// the node is a leaf
n.SetDirty(false)
n.SetDirty(false)

switch n.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return nil
}

branch := n.(*node.Branch)

for _, child := range branch.Children {
if child == nil {
continue
Expand All @@ -382,8 +396,6 @@ func (t *Trie) writeDirty(db chaindb.Batch, n Node) error {
}
}

branch.SetDirty(false)

return nil
}

Expand Down Expand Up @@ -420,12 +432,14 @@ func (t *Trie) getInsertedNodeHashes(n Node) (hashes []common.Hash, err error) {

hashes = append(hashes, common.BytesToHash(hash))

branch, ok := n.(*node.Branch)
if !ok {
// node is a leaf
switch n.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return hashes, nil
}

branch := n.(*node.Branch)

for _, child := range branch.Children {
if child == nil {
continue
Expand Down
7 changes: 5 additions & 2 deletions lib/trie/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ func find(parent Node, key []byte, recorder recorder) error {

recorder.Record(hash, enc)

b, ok := parent.(*node.Branch)
if !ok {
switch parent.Type() {
case node.BranchType, node.BranchWithValueType:
default: // not a branch
return nil
}

b := parent.(*node.Branch)

length := lenCommonPrefix(b.Key, key)

// found the value at this node
Expand Down
21 changes: 13 additions & 8 deletions lib/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,23 @@ func (t *Trie) tryPut(key, value []byte) {

// insert attempts to insert a key with value into the trie
func (t *Trie) insert(parent Node, key []byte, value Node) Node {
switch p := t.maybeUpdateGeneration(parent).(type) {
case *node.Branch:
newParent := t.maybeUpdateGeneration(parent)
if newParent == nil {
value.SetKey(key)
return value
}

switch newParent.Type() {
case node.BranchType, node.BranchWithValueType:
p := newParent.(*node.Branch)
n := t.updateBranch(p, key, value)

if p != nil && n != nil && n.IsDirty() {
p.SetDirty(true)
}
return n
case nil:
value.SetKey(key)
return value
case *node.Leaf:
case node.LeafType:
p := newParent.(*node.Leaf)
// if a value already exists in the trie at this key, overwrite it with the new value
// if the values are the same, don't mark node dirty
if p.Value != nil && bytes.Equal(p.Key, key) {
Expand Down Expand Up @@ -324,9 +329,9 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node {
}

return br
default:
panic("unknown node type: " + fmt.Sprint(newParent.Type()))
}
// This will never happen.
return nil
}

// updateBranch attempts to add the value node to a branch
Expand Down

0 comments on commit f986135

Please sign in to comment.