Skip to content

Commit

Permalink
Merge pull request #685 from Chia-Network/stubtest
Browse files Browse the repository at this point in the history
add stubtest
  • Loading branch information
altendky authored Sep 25, 2024
2 parents 9b523c7 + c4087fc commit 2e3d3ba
Show file tree
Hide file tree
Showing 7 changed files with 802 additions and 680 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ jobs:
os: [macos-latest, ubuntu-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]

env:
stubtest_args: ${{ matrix.python-version == '3.11' && '--allowlist wheel/stubtest.allowlist.3-11-plus' || ''}}

steps:
- uses: actions/checkout@v4
with:
Expand Down Expand Up @@ -64,6 +67,26 @@ jobs:
run: |
mypy --ignore-missing-imports tests
- name: python mypy stubtest
shell: bash
run: |
FAILURE=0
echo "::group::concise"
if ! stubtest ${{ env.stubtest_args }} --allowlist wheel/stubtest.allowlist --concise chia_rs
then
FAILURE=1
fi
echo "::endgroup::"
echo "::group::complete"
if ! stubtest ${{ env.stubtest_args }} --allowlist wheel/stubtest.allowlist chia_rs
then
FAILURE=1
fi
echo "::endgroup::"
exit ${FAILURE}
- name: python black
run: |
black --check tests
Expand Down
38 changes: 20 additions & 18 deletions crates/chia_py_streamable_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,28 +127,30 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS
}
});

py_protocol.extend(quote! {
#[pyo3::pymethods]
impl #ident {
#[pyo3(signature = (**kwargs))]
fn replace(&self, kwargs: Option<&pyo3::types::PyDict>) -> pyo3::PyResult<Self> {
let mut ret = self.clone();
if let Some(kwargs) = kwargs {
let iter: pyo3::types::iter::PyDictIterator = kwargs.iter();
for (field, value) in iter {
let field = field.extract::<String>()?;
match field.as_str() {
#(stringify!(#fnames_maybe_upper) => {
ret.#fnames = value.extract()?;
}),*
_ => { return Err(pyo3::exceptions::PyKeyError::new_err(format!("unknown field {field}"))); }
if !named.is_empty() {
py_protocol.extend(quote! {
#[pyo3::pymethods]
impl #ident {
#[pyo3(signature = (**kwargs))]
fn replace(&self, kwargs: Option<&pyo3::types::PyDict>) -> pyo3::PyResult<Self> {
let mut ret = self.clone();
if let Some(kwargs) = kwargs {
let iter: pyo3::types::iter::PyDictIterator = kwargs.iter();
for (field, value) in iter {
let field = field.extract::<String>()?;
match field.as_str() {
#(stringify!(#fnames_maybe_upper) => {
ret.#fnames = value.extract()?;
}),*
_ => { return Err(pyo3::exceptions::PyKeyError::new_err(format!("unknown field {field}"))); }
}
}
}
Ok(ret)
}
Ok(ret)
}
}
});
});
}
}
syn::Fields::Unnamed(FieldsUnnamed { .. }) => {}
syn::Fields::Unit => {
Expand Down
60 changes: 36 additions & 24 deletions wheel/generate_type_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,37 @@ def add_indent(x: str):
all_replace_parameters = []
for m in members:
replace_param_name, replace_type = m.split(":")
if replace_param_name.startswith("a") and replace_param_name[1:].isnumeric():
continue
all_replace_parameters.append(
f"{replace_param_name}: Union[{replace_type}, _Unspec] = _Unspec()"
)

if extra is not None:
members.extend(extra)

# TODO: could theoretically be detected from the use of #[streamable(subclass)]
inheritable = name in ["SpendBundle"]

# TODO: is __richcmp__ ever actually present?
# def __richcmp__(self) -> Any: ...
file.write(
f"""
{"" if inheritable else "@final"}
class {name}:{"".join(map(add_indent, members))}
def __init__(
self{init_args}
) -> None: ...
def __hash__(self) -> int: ...
def __repr__(self) -> str: ...
def __richcmp__(self) -> Any: ...
def __deepcopy__(self) -> {name}: ...
def __deepcopy__(self, memo: object) -> {name}: ...
def __copy__(self) -> {name}: ...
@classmethod
def from_bytes(cls, bytes) -> Self: ...
def from_bytes(cls, blob: bytes) -> Self: ...
@classmethod
def from_bytes_unchecked(cls, bytes) -> Self: ...
def from_bytes_unchecked(cls, blob: bytes) -> Self: ...
@classmethod
def parse_rust(cls, ReadableBuffer, bool = False) -> Tuple[Self, int]: ...
def parse_rust(cls, blob: ReadableBuffer, trusted: bool = False) -> Tuple[Self, int]: ...
def to_bytes(self) -> bytes: ...
def __bytes__(self) -> bytes: ...
def stream_to_bytes(self) -> bytes: ...
Expand Down Expand Up @@ -159,6 +166,7 @@ def parse_rust_source(filename: str, upper_case: bool) -> List[Tuple[str, List[s
extra_members = {
"Coin": [
"def name(self) -> bytes32: ...",
"@classmethod\n def from_parent(cls, _coin: Self): ...",
],
"ClassgroupElement": [
"@staticmethod\n def create(bytes) -> ClassgroupElement: ...",
Expand Down Expand Up @@ -207,7 +215,7 @@ def parse_rust_source(filename: str, upper_case: bool) -> List[Tuple[str, List[s
"Program": [
"def get_tree_hash(self) -> bytes32: ...",
"@staticmethod\n def default() -> Program: ...",
"@staticmethod\n def fromhex(hex) -> Program: ...",
"@staticmethod\n def fromhex(h: str) -> Program: ...",
"def run_mempool_with_cost(self, max_cost: int, args: object) -> Tuple[int, ChiaProgram]: ...",
"def run_with_cost(self, max_cost: int, args: object) -> Tuple[int, ChiaProgram]: ...",
"def _run(self, max_cost: int, flags: int, args: object) -> Tuple[int, ChiaProgram]: ...",
Expand All @@ -217,7 +225,8 @@ def parse_rust_source(filename: str, upper_case: bool) -> List[Tuple[str, List[s
"def uncurry(self) -> Tuple[ChiaProgram, ChiaProgram]: ...",
],
"SpendBundle": [
"@classmethod\n def aggregate(cls, sbs: List[SpendBundle]) -> Self: ...",
"@classmethod\n def from_parent(cls, spend_bundle: Self): ...",
"@classmethod\n def aggregate(cls, spend_bundles: List[SpendBundle]) -> Self: ...",
"def name(self) -> bytes32: ...",
"def removals(self) -> List[Coin]: ...",
"def additions(self) -> List[Coin]: ...",
Expand All @@ -232,6 +241,9 @@ def parse_rust_source(filename: str, upper_case: bool) -> List[Tuple[str, List[s
"def ip_iters(self, constants: ConsensusConstants) -> uint64: ...",
"def sp_total_iters(self, constants: ConsensusConstants) -> uint128: ...",
],
"CoinSpend": [
"@classmethod\n def from_parent(cls, cs: Self): ...",
],
}

classes = []
Expand All @@ -254,7 +266,7 @@ def parse_rust_source(filename: str, upper_case: bool) -> List[Tuple[str, List[s
# this file is generated by generate_type_stubs.py
#
from typing import List, Optional, Sequence, Tuple, Union, Dict, Any, ClassVar
from typing import List, Optional, Sequence, Tuple, Union, Dict, Any, ClassVar, final
from .sized_bytes import bytes32, bytes100
from .sized_ints import uint8, uint16, uint32, uint64, uint128, int8, int16, int32, int64
from typing_extensions import Self
Expand All @@ -268,27 +280,23 @@ class _Unspec:
def solution_generator(spends: Sequence[Tuple[Coin, bytes, bytes]]) -> bytes: ...
def solution_generator_backrefs(spends: Sequence[Tuple[Coin, bytes, bytes]]) -> bytes: ...
def compute_merkle_set_root(items: Sequence[bytes]) -> bytes: ...
def compute_merkle_set_root(values: Sequence[bytes]) -> bytes: ...
def supports_fast_forward(spend: CoinSpend) -> bool : ...
def fast_forward_singleton(spend: CoinSpend, new_coin: Coin, new_parent: Coin) -> bytes: ...
def run_block_generator(
program: ReadableBuffer, args: List[ReadableBuffer], max_cost: int, flags: int, constants: ConsensusConstants
program: ReadableBuffer, block_refs: List[ReadableBuffer], max_cost: int, flags: int, constants: ConsensusConstants
) -> Tuple[Optional[int], Optional[SpendBundleConditions]]: ...
def run_block_generator2(
program: ReadableBuffer, args: List[ReadableBuffer], max_cost: int, flags: int, constants: ConsensusConstants
program: ReadableBuffer, block_refs: List[ReadableBuffer], max_cost: int, flags: int, constants: ConsensusConstants
) -> Tuple[Optional[int], Optional[SpendBundleConditions]]: ...
def run_puzzle(
puzzle: bytes, solution: bytes, parent_id: bytes32, amount: int, max_cost: int, flags: int, constants: ConsensusConstants
) -> SpendBundleConditions: ...
def deserialize_proof(
proof: bytes
) -> MerkleSet: ...
def confirm_included_already_hashed(
root: bytes32,
item: bytes32,
Expand Down Expand Up @@ -336,22 +344,25 @@ def run_chia_program(
program: bytes, args: bytes, max_cost: int, flags: int
) -> Tuple[int, LazyNode]: ...
@final
class LazyNode:
pair: Optional[Tuple[LazyNode, LazyNode]]
atom: Optional[bytes]
def serialized_length(program: ReadableBuffer) -> int: ...
def tree_hash(program: ReadableBuffer) -> bytes32: ...
def tree_hash(blob: ReadableBuffer) -> bytes32: ...
def get_puzzle_and_solution_for_coin(program: ReadableBuffer, args: ReadableBuffer, max_cost: int, find_parent: bytes32, find_amount: int, find_ph: bytes32, flags: int) -> Tuple[bytes, bytes]: ...
def get_puzzle_and_solution_for_coin2(program: Program, block_refs: List[ReadableBuffer], max_cost: int, find_coin: Coin, flags: int) -> Tuple[Program, Program]: ...
def get_puzzle_and_solution_for_coin2(generator: Program, block_refs: List[ReadableBuffer], max_cost: int, find_coin: Coin, flags: int) -> Tuple[Program, Program]: ...
@final
class BLSCache:
def __init__(self, cache_size: Optional[int] = 50000) -> None: ...
def len(self) -> int: ...
def aggregate_verify(self, pks: List[G1Element], msgs: List[bytes], sig: G2Element) -> bool: ...
def items(self) -> List[Tuple[bytes, GTElement]]: ...
def update(self, other: Sequence[Tuple[bytes, GTElement]]) -> None: ...
@final
class AugSchemeMPL:
@staticmethod
def sign(pk: PrivateKey, msg: bytes, prepend_pk: Optional[G1Element] = None) -> G2Element: ...
Expand All @@ -366,15 +377,16 @@ def key_gen(seed: bytes) -> PrivateKey: ...
@staticmethod
def g2_from_message(msg: bytes) -> G2Element: ...
@staticmethod
def derive_child_sk(pk: PrivateKey, index: int) -> PrivateKey: ...
def derive_child_sk(sk: PrivateKey, index: int) -> PrivateKey: ...
@staticmethod
def derive_child_sk_unhardened(pk: PrivateKey, index: int) -> PrivateKey: ...
def derive_child_sk_unhardened(sk: PrivateKey, index: int) -> PrivateKey: ...
@staticmethod
def derive_child_pk_unhardened(pk: G1Element, index: int) -> G1Element: ...
@final
class MerkleSet:
def get_root(self) -> bytes32: ...
def is_included_already_hashed(self, to_check: bytes) -> Tuple[bool, bytes]: ...
def is_included_already_hashed(self, included_leaf: bytes32) -> Tuple[bool, bytes]: ...
def __init__(
self,
leafs: List[bytes32],
Expand All @@ -397,7 +409,7 @@ def __init__(
"def __str__(self) -> str: ...",
"def __add__(self, other: G1Element) -> G1Element: ...",
"def __iadd__(self, other: G1Element) -> G1Element: ...",
"def derive_unhardened(self, int) -> G1Element: ...",
"def derive_unhardened(self, idx: int) -> G1Element: ...",
],
)
print_class(
Expand Down Expand Up @@ -436,10 +448,10 @@ def __init__(
"def get_g1(self) -> G1Element: ...",
"def __str__(self) -> str: ...",
"def public_key(self) -> G1Element: ...",
"def derive_hardened(self, int) -> PrivateKey: ...",
"def derive_unhardened(self, int) -> PrivateKey: ...",
"def derive_hardened(self, idx: int) -> PrivateKey: ...",
"def derive_unhardened(self, idx: int) -> PrivateKey: ...",
"@staticmethod",
"def from_seed(bytes) -> PrivateKey: ...",
"def from_seed(seed: bytes) -> PrivateKey: ...",
],
)

Expand Down
Loading

0 comments on commit 2e3d3ba

Please sign in to comment.