Skip to content

Commit

Permalink
Mypy updates
Browse files Browse the repository at this point in the history
Signed-off-by: cyc60 <avsysoev60@gmail.com>
  • Loading branch information
cyc60 committed Sep 14, 2023
1 parent bac363a commit f6ff288
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 185 deletions.
34 changes: 17 additions & 17 deletions multiproof/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from dataclasses import dataclass
from typing import Any, List
from typing import Any

from web3 import Web3

Expand All @@ -9,9 +9,9 @@

@dataclass
class MultiProof:
leaves: List[Any]
proof: List[Any]
proof_flags: List[bool]
leaves: list[Any]
proof: list[Any]
proof_flags: list[bool]


def hash_pair(a: bytes, b: bytes) -> bytes:
Expand Down Expand Up @@ -41,33 +41,33 @@ def sibling_index(i: int) -> int:
raise ValueError('Root has no siblings')


def is_tree_node(tree: List[Any], i: int) -> bool:
def is_tree_node(tree: list[Any], i: int) -> bool:
return 0 <= i < len(tree)


def is_internal_node(tree: List[Any], i: int) -> bool:
def is_internal_node(tree: list[Any], i: int) -> bool:
return is_tree_node(tree, left_child_index(i))


def is_leaf_node(tree: List[Any], i: int) -> bool:
def is_leaf_node(tree: list[Any], i: int) -> bool:
return is_tree_node(tree, i) and not is_internal_node(tree, i)


def is_valid_merkle_node(node: bytes) -> bool:
return len(node) == 32


def check_tree_node(tree: List[Any], i: int) -> None:
def check_tree_node(tree: list[Any], i: int) -> None:
if not is_tree_node(tree, i):
raise ValueError("Index is not in tree")


def check_internal_node(tree: List[Any], i: int) -> None:
def check_internal_node(tree: list[Any], i: int) -> None:
if not is_internal_node(tree, i):
raise ValueError("Index is not an internal tree node")


def check_leaf_node(tree: List[Any], i: int) -> None:
def check_leaf_node(tree: list[Any], i: int) -> None:
if not is_leaf_node(tree, i):
raise ValueError("Index is not a leaf")

Expand All @@ -77,14 +77,14 @@ def check_valid_merkle_node(node: bytes) -> None:
raise ValueError("Merkle tree nodes must be Uint8Array of length 32")


def make_merkle_tree(leaves: List[bytes]) -> List[bytes]:
def make_merkle_tree(leaves: list[bytes]) -> list[bytes]:
for leaf in leaves:
check_valid_merkle_node(leaf)

if len(leaves) == 0:
raise ValueError("Expected non-zero number of leaves")

tree: List[bytes] = [b''] * (2 * len(leaves) - 1)
tree: list[bytes] = [b''] * (2 * len(leaves) - 1)

for index, leaf in enumerate(leaves):
tree[len(tree) - 1 - index] = leaf
Expand All @@ -97,7 +97,7 @@ def make_merkle_tree(leaves: List[bytes]) -> List[bytes]:
return tree


def get_proof(tree: List[bytes], index: int) -> List[bytes]:
def get_proof(tree: list[bytes], index: int) -> list[bytes]:
check_leaf_node(tree, index)

proof = []
Expand All @@ -108,7 +108,7 @@ def get_proof(tree: List[bytes], index: int) -> List[bytes]:
return proof


def process_proof(leaf: bytes, proof: List[bytes]) -> bytes:
def process_proof(leaf: bytes, proof: list[bytes]) -> bytes:
check_valid_merkle_node(leaf)
for item in proof:
check_valid_merkle_node(item)
Expand All @@ -118,7 +118,7 @@ def process_proof(leaf: bytes, proof: List[bytes]) -> bytes:
return result


def get_multi_proof(tree: List[bytes], indices: List[int]) -> MultiProof:
def get_multi_proof(tree: list[bytes], indices: list[int]) -> MultiProof:
for index in indices:
check_leaf_node(tree, index)

Expand Down Expand Up @@ -183,7 +183,7 @@ def process_multi_proof(multiproof: MultiProof) -> bytes:
return pop_safe(stack) or proof.pop(0)


def is_valid_merkle_tree(tree: List[bytes]) -> bool:
def is_valid_merkle_tree(tree: list[bytes]) -> bool:
for i, node in enumerate(tree):
if not is_valid_merkle_node(node):
return False
Expand All @@ -200,7 +200,7 @@ def is_valid_merkle_tree(tree: List[bytes]) -> bool:
return len(tree) > 0


def pop_safe(array: List[Any]) -> Any:
def pop_safe(array: list[Any]) -> Any:
try:
return array.pop()
except IndexError:
Expand Down
38 changes: 19 additions & 19 deletions multiproof/standart.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from functools import cmp_to_key
from typing import Any, Dict, List, Union
from typing import Any

from eth_abi import encode as abi_encode
from web3 import Web3
Expand All @@ -21,9 +21,9 @@ class LeafValue:

@dataclass
class StandardMerkleTreeData:
tree: List[str]
values: List[LeafValue]
leaf_encoding: List[str]
tree: list[str]
values: list[LeafValue]
leaf_encoding: list[str]
format: str = 'standard-v1'


Expand All @@ -34,17 +34,17 @@ class HashedValue:
hash: bytes


def standard_leaf_hash(values: Any, types: List[str]) -> bytes:
def standard_leaf_hash(values: Any, types: list[str]) -> bytes:
return Web3.keccak(Web3.keccak(abi_encode(types, values)))


class StandardMerkleTree:
_hash_lookup: Dict[str, int]
tree: List[bytes]
values: List[LeafValue]
leaf_encoding: List[str]
_hash_lookup: dict[str, int]
tree: list[bytes]
values: list[LeafValue]
leaf_encoding: list[str]

def __init__(self, tree: List[bytes], values: List[LeafValue], leaf_encoding: List[str]):
def __init__(self, tree: list[bytes], values: list[LeafValue], leaf_encoding: list[str]):
self.tree = tree
self.values = values
self.leaf_encoding = leaf_encoding
Expand All @@ -53,8 +53,8 @@ def __init__(self, tree: List[bytes], values: List[LeafValue], leaf_encoding: Li
self._hash_lookup[to_hex(standard_leaf_hash(leaf_value.value, leaf_encoding))] = index

@staticmethod
def of(values: List[Any], leaf_encoding: List[str]):
hashed_values: List[HashedValue] = []
def of(values: list[Any], leaf_encoding: list[str]) -> 'StandardMerkleTree':
hashed_values: list[HashedValue] = []
for index, value in enumerate(values):
hashed_values.append(
HashedValue(value=value, index=index, hash=standard_leaf_hash(value, leaf_encoding))
Expand Down Expand Up @@ -122,16 +122,16 @@ def validate(self) -> None:
if not is_valid_merkle_tree(self.tree):
raise ValueError("Merkle tree is invalid")

def leaf_hash(self, leaf) -> str:
def leaf_hash(self, leaf: Any) -> str:
return to_hex(standard_leaf_hash(leaf, self.leaf_encoding))

def leaf_lookup(self, leaf) -> int:
def leaf_lookup(self, leaf: Any) -> int:
v = self._hash_lookup[self.leaf_hash(leaf)]
if v is None:
raise ValueError("Leaf is not in tree")
return v

def get_proof(self, leaf: Union[LeafValue, int]) -> List[str]:
def get_proof(self, leaf: LeafValue | int) -> list[str]:
# input validity
value_index: int = leaf # type: ignore
if not isinstance(leaf, int):
Expand All @@ -151,7 +151,7 @@ def get_proof(self, leaf: Union[LeafValue, int]) -> List[str]:

return [to_hex(p) for p in proof]

def get_multi_proof(self, leaves) -> MultiProof:
def get_multi_proof(self, leaves: list[Any]) -> MultiProof:
# input validity
value_indices = []
for leaf in leaves:
Expand Down Expand Up @@ -183,7 +183,7 @@ def get_multi_proof(self, leaves) -> MultiProof:
def verify_leaf(self, leaf: int, proof: list[str]) -> bool:
return self._verify_leaf(self._get_leaf_hash(leaf), [hex_to_bytes(p) for p in proof])

def _verify_leaf(self, leaf_hash: bytes, proof: List[bytes]) -> bool:
def _verify_leaf(self, leaf_hash: bytes, proof: list[bytes]) -> bool:
implied_root = process_proof(leaf_hash, proof)
return equals_bytes(implied_root, self.tree[0])

Expand Down Expand Up @@ -221,8 +221,8 @@ def __str__(self):
if len(self.tree) == 0:
raise ValueError("Expected non-zero number of nodes")

stack: List = [[0, []]]
lines: List = []
stack: list = [[0, []]]
lines: list = []

while len(stack) > 0:
i, path = stack.pop()
Expand Down
2 changes: 1 addition & 1 deletion multiproof/tests/test_standart.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ZERO = to_hex(ZERO_BYTES)


def characters(s: str):
def characters(s: str) -> tuple[list[list[str]], StandardMerkleTree]:
l = [[x] for x in s]
tree = StandardMerkleTree.of(l, ['string'])
return l, tree
Expand Down
5 changes: 1 addition & 4 deletions multiproof/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from typing import Any, List


def check_bounds(array: List[Any], index: int) -> None:
def check_bounds(array: list, index: int) -> None:
if index < 0 or index >= len(array):
raise ValueError("Index out of bounds")
Loading

0 comments on commit f6ff288

Please sign in to comment.