diff --git a/tests/test_blspy_fidelity.py b/tests/test_blspy_fidelity.py index 9af92566c..8e89a1dd4 100644 --- a/tests/test_blspy_fidelity.py +++ b/tests/test_blspy_fidelity.py @@ -4,6 +4,7 @@ import sys from typing import Any, Type import pytest +from concurrent.futures import ThreadPoolExecutor def randbytes(n: int) -> bytes: @@ -16,7 +17,8 @@ def randbytes(n: int) -> bytes: # make sure chia_rs counterpart behaves the same as blspy def test_bls() -> None: print() - for round in range(200): + + def run_test_suite(round: int) -> None: sys.stdout.write(f"\r{round} ") sys.stdout.flush() seed = randbytes(32) @@ -210,6 +212,11 @@ def test_bls() -> None: with pytest.raises(ValueError, match="invalid length"): obj2 = klass.from_json_dict(bytes(obj) + b"a") + pool = ThreadPoolExecutor(max_workers=8) + for round in range(200): + pool.submit(run_test_suite, round) + pool.shutdown() + # ------------------------------------- 8< ---------------------------------- # diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 51f661978..f100bd97c 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -292,12 +292,13 @@ impl AugSchemeMPL { } #[staticmethod] - pub fn verify(pk: &PublicKey, msg: &[u8], sig: &Signature) -> bool { - chia_bls::verify(sig, pk, msg) + pub fn verify(py: Python<'_>, pk: &PublicKey, msg: &[u8], sig: &Signature) -> bool { + py.allow_threads(|| chia_bls::verify(sig, pk, msg)) } #[staticmethod] pub fn aggregate_verify( + py: Python<'_>, pks: &Bound<'_, PyList>, msgs: &Bound<'_, PyList>, sig: &Signature, @@ -314,7 +315,7 @@ impl AugSchemeMPL { data.push((pk, msg)); } - Ok(chia_bls::aggregate_verify(sig, data)) + py.allow_threads(|| Ok(chia_bls::aggregate_verify(sig, data))) } #[staticmethod]