Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ProofNode to use interface #2176

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 106 additions & 92 deletions core/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,59 @@
"github.com/NethermindEth/juno/core/felt"
)

// https://github.com/starknet-io/starknet-p2p-specs/blob/main/p2p/proto/snapshot.proto#L6
type ProofNode struct {
Binary *Binary
Edge *Edge
var (
ErrUnknownProofNode = errors.New("unknown proof node")
ErrChildHashNotFound = errors.New("can't determine the child hash from the parent and child")
)

type ProofNode interface {
Hash(hash hashFunc) *felt.Felt
Len() uint8
PrettyPrint()
}

// Note: does not work for leaves
func (pn *ProofNode) Hash(hash hashFunc) *felt.Felt {
switch {
case pn.Binary != nil:
return hash(pn.Binary.LeftHash, pn.Binary.RightHash)
case pn.Edge != nil:
length := make([]byte, len(pn.Edge.Path.bitset))
length[len(pn.Edge.Path.bitset)-1] = pn.Edge.Path.len
pathFelt := pn.Edge.Path.Felt()
lengthFelt := new(felt.Felt).SetBytes(length)
return new(felt.Felt).Add(hash(pn.Edge.Child, &pathFelt), lengthFelt)
default:
return nil
}
type Binary struct {
LeftHash *felt.Felt
RightHash *felt.Felt
}

func (pn *ProofNode) Len() uint8 {
if pn.Binary != nil {
return 1
}
return pn.Edge.Path.len
func (b *Binary) Hash(hash hashFunc) *felt.Felt {
return hash(b.LeftHash, b.RightHash)
}

func (pn *ProofNode) PrettyPrint() {
if pn.Binary != nil {
fmt.Printf(" Binary:\n")
fmt.Printf(" LeftHash: %v\n", pn.Binary.LeftHash)
fmt.Printf(" RightHash: %v\n", pn.Binary.RightHash)
}
if pn.Edge != nil {
fmt.Printf(" Edge:\n")
fmt.Printf(" Child: %v\n", pn.Edge.Child)
fmt.Printf(" Path: %v\n", pn.Edge.Path)
}
func (b *Binary) Len() uint8 {
return 1
}

type Binary struct {
LeftHash *felt.Felt
RightHash *felt.Felt
func (b *Binary) PrettyPrint() {
fmt.Printf(" Binary:\n")
fmt.Printf(" LeftHash: %v\n", b.LeftHash)
fmt.Printf(" RightHash: %v\n", b.RightHash)
}

type Edge struct {
Child *felt.Felt // child hash
Path *Key // path from parent to child
}

func (e *Edge) Hash(hash hashFunc) *felt.Felt {
length := make([]byte, len(e.Path.bitset))
length[len(e.Path.bitset)-1] = e.Path.len
pathFelt := e.Path.Felt()
lengthFelt := new(felt.Felt).SetBytes(length)
return new(felt.Felt).Add(hash(e.Child, &pathFelt), lengthFelt)
}

func (e *Edge) Len() uint8 {
return e.Path.Len()
}

func (e *Edge) PrettyPrint() {
fmt.Printf(" Edge:\n")
fmt.Printf(" Child: %v\n", e.Child)
fmt.Printf(" Path: %v\n", e.Path)
}

func GetBoundaryProofs(leftBoundary, rightBoundary *Key, tri *Trie) ([2][]ProofNode, error) {
proofs := [2][]ProofNode{}
leftProof, err := GetProof(leftBoundary, tri)
Expand Down Expand Up @@ -110,19 +111,19 @@
rightHash := rNode.Value
if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) {
edgePath := path(sNode.node.Right, sNode.key)
rEdge := ProofNode{Edge: &Edge{
rEdge := &Edge{
Path: &edgePath,
Child: rNode.Value,
}}
}
rightHash = rEdge.Hash(tri.hash)
}
leftHash := lNode.Value
if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) {
edgePath := path(sNode.node.Left, sNode.key)
lEdge := ProofNode{Edge: &Edge{
lEdge := &Edge{
Path: &edgePath,
Child: lNode.Value,
}}
}
leftHash = lEdge.Hash(tri.hash)
}
binary := &Binary{
Expand All @@ -139,19 +140,20 @@
func pathSplitOccurredCheck(mergedPath []ProofNode, nodeHashes map[felt.Felt]ProofNode) error {
splitHappened := false
for _, node := range mergedPath {
if node.Edge != nil {
switch node := node.(type) {
case *Edge:
continue
}

_, leftExists := nodeHashes[*node.Binary.LeftHash]
_, rightExists := nodeHashes[*node.Binary.RightHash]

if leftExists && rightExists {
if splitHappened {
return errors.New("split happened more than once")
case *Binary:
_, leftExists := nodeHashes[*node.LeftHash]
_, rightExists := nodeHashes[*node.RightHash]
if leftExists && rightExists {
if splitHappened {
return errors.New("split happened more than once")
}
splitHappened = true
}

splitHappened = true
default:
return fmt.Errorf("%w: %T", ErrUnknownProofNode, node)

Check warning on line 156 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L155-L156

Added lines #L155 - L156 were not covered by tests
}
}
return nil
Expand All @@ -173,9 +175,10 @@
func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Felt]ProofNode) {
*path = append(*path, currNode)

if currNode.Binary != nil {
nodeLeft, leftExist := nodeHashes[*currNode.Binary.LeftHash]
nodeRight, rightExist := nodeHashes[*currNode.Binary.RightHash]
switch currNode := currNode.(type) {
case *Binary:
nodeLeft, leftExist := nodeHashes[*currNode.LeftHash]
nodeRight, rightExist := nodeHashes[*currNode.RightHash]

if leftExist && rightExist {
return
Expand All @@ -184,8 +187,8 @@
} else if rightExist {
traverseNodes(nodeRight, path, nodeHashes)
}
} else if currNode.Edge != nil {
edgeNode, exist := nodeHashes[*currNode.Edge.Child]
case *Edge:
edgeNode, exist := nodeHashes[*currNode.Child]
if exist {
traverseNodes(edgeNode, path, nodeHashes)
}
Expand Down Expand Up @@ -269,8 +272,8 @@

currNode = commonPath[len(commonPath)-1]

leftNode := nodeHashes[*currNode.Binary.LeftHash]
rightNode := nodeHashes[*currNode.Binary.RightHash]
leftNode := nodeHashes[*currNode.(*Binary).LeftHash]
rightNode := nodeHashes[*currNode.(*Binary).RightHash]

traverseNodes(leftNode, &leftPath, nodeHashes)
traverseNodes(rightNode, &rightPath, nodeHashes)
Expand Down Expand Up @@ -298,11 +301,11 @@
isLeaf := sNode.key.len == tri.height

if sNodeEdge != nil && !isLeaf { // Internal Edge
proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}, {Binary: sNodeBinary}}...)
proofNodes = append(proofNodes, sNodeEdge, sNodeBinary)
} else if sNodeEdge == nil && !isLeaf { // Internal Binary
proofNodes = append(proofNodes, []ProofNode{{Binary: sNodeBinary}}...)
proofNodes = append(proofNodes, sNodeBinary)
} else if sNodeEdge != nil && isLeaf { // Leaf Edge
proofNodes = append(proofNodes, []ProofNode{{Edge: sNodeEdge}}...)
proofNodes = append(proofNodes, sNodeEdge)
} else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf
break
}
Expand All @@ -321,16 +324,16 @@
return false
}

switch {
case proofNode.Binary != nil:
switch proofNode := proofNode.(type) {
case *Binary:
if remainingPath.Test(remainingPath.Len() - 1) {
expectedHash = proofNode.Binary.RightHash
expectedHash = proofNode.RightHash
} else {
expectedHash = proofNode.Binary.LeftHash
expectedHash = proofNode.LeftHash
}
remainingPath.RemoveLastBit()
case proofNode.Edge != nil:
subKey, err := remainingPath.SubKey(proofNode.Edge.Path.Len())
case *Edge:
subKey, err := remainingPath.SubKey(proofNode.Path.Len())
if err != nil {
return false
}
Expand All @@ -342,11 +345,11 @@
return true
}

if !proofNode.Edge.Path.Equal(subKey) {
if !proofNode.Path.Equal(subKey) {
return false
}
expectedHash = proofNode.Edge.Child
remainingPath.Truncate(251 - proofNode.Edge.Path.Len()) //nolint:mnd
expectedHash = proofNode.Child
remainingPath.Truncate(251 - proofNode.Path.Len()) //nolint:mnd
kirugan marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down Expand Up @@ -438,27 +441,33 @@

// compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key
func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) {
parent := &proofNodes[idx]
parent := proofNodes[idx]

if idx == len(proofNodes)-1 {
if parent.Edge != nil {
if _, ok := parent.(*Edge); ok {
return 1, parent.Len(), nil
}
return 0, parent.Len(), nil
}

child := &proofNodes[idx+1]

switch {
case parent.Edge != nil && child.Binary != nil:
return 1, parent.Edge.Path.len, nil
case parent.Binary != nil && child.Edge != nil:
child := proofNodes[idx+1]
_, isChildBinary := child.(*Binary)
isChildEdge := !isChildBinary
switch parent := parent.(type) {
case *Edge:
if isChildEdge {
break

Check warning on line 459 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L459

Added line #L459 was not covered by tests
}
return 1, parent.Len(), nil
case *Binary:
if isChildBinary {
break
}
childHash := child.Hash(hashF)
if parent.Binary.LeftHash.Equal(childHash) || parent.Binary.RightHash.Equal(childHash) {
return 1, child.Edge.Path.len, nil
} else {
return 0, 0, errors.New("can't determine the child hash from the parent and child")
if parent.LeftHash.Equal(childHash) || parent.RightHash.Equal(childHash) {
return 1, child.Len(), nil
}
return 0, 0, ErrChildHashNotFound

Check warning on line 470 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L470

Added line #L470 was not covered by tests
}

return 0, 1, nil
Expand Down Expand Up @@ -539,6 +548,7 @@
break
}
}

return pathNodes, nil
}

Expand All @@ -558,14 +568,20 @@
}

func getLeftRightHash(parentInd int, proofNodes []ProofNode) (*felt.Felt, *felt.Felt, error) {
parent := &proofNodes[parentInd]
if parent.Binary == nil {
parent := proofNodes[parentInd]

switch parent := parent.(type) {
case *Binary:
return parent.LeftHash, parent.RightHash, nil
case *Edge:
if parentInd+1 > len(proofNodes)-1 {
return nil, nil, errors.New("cant get hash of children from proof node, out of range")
}
parent = &proofNodes[parentInd+1]
parentBinary := proofNodes[parentInd+1].(*Binary)
return parentBinary.LeftHash, parentBinary.RightHash, nil
default:
return nil, nil, fmt.Errorf("%w: %T", ErrUnknownProofNode, parent)

Check warning on line 583 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L582-L583

Added lines #L582 - L583 were not covered by tests
}
return parent.Binary.LeftHash, parent.Binary.RightHash, nil
}

func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key,
Expand All @@ -576,16 +592,14 @@

var height uint8
if len(pathNodes) > 0 {
if proofNodes[idx].Edge != nil {
height = pathNodes[len(pathNodes)-1].key.len + proofNodes[idx].Edge.Path.len
if p, ok := proofNodes[idx].(*Edge); ok {
height = pathNodes[len(pathNodes)-1].key.len + p.Path.len

Check warning on line 596 in core/trie/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie/proof.go#L596

Added line #L596 was not covered by tests
} else {
height = pathNodes[len(pathNodes)-1].key.len + 1
}
} else {
height = 0
}

if pNode.Binary != nil {
if _, ok := pNode.(*Binary); ok {
crntKey, err = leafKey.SubKey(height)
} else {
crntKey, err = leafKey.SubKey(height + compressedParentOffset)
Expand Down
Loading