diff --git a/Cargo.lock b/Cargo.lock index 19c750cfe2d..2ad43adc9b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3547,7 +3547,7 @@ dependencies = [ [[package]] name = "mithril-aggregator" -version = "0.5.41" +version = "0.5.42" dependencies = [ "anyhow", "async-trait", @@ -3848,7 +3848,7 @@ dependencies = [ [[package]] name = "mithril-signer" -version = "0.2.163" +version = "0.2.164" dependencies = [ "anyhow", "async-trait", diff --git a/mithril-aggregator/Cargo.toml b/mithril-aggregator/Cargo.toml index abb68c956ba..576a9674435 100644 --- a/mithril-aggregator/Cargo.toml +++ b/mithril-aggregator/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-aggregator" -version = "0.5.41" +version = "0.5.42" description = "A Mithril Aggregator server" authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-aggregator/src/services/cardano_transactions_importer.rs b/mithril-aggregator/src/services/cardano_transactions_importer.rs index 1e12654ced6..996e64d0b1f 100644 --- a/mithril-aggregator/src/services/cardano_transactions_importer.rs +++ b/mithril-aggregator/src/services/cardano_transactions_importer.rs @@ -3,6 +3,7 @@ use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; +use anyhow::Context; use async_trait::async_trait; use slog::{debug, Logger}; use tokio::{runtime::Handle, task}; @@ -51,6 +52,7 @@ pub trait TransactionStore: Send + Sync { } /// Import and store [CardanoTransaction]. +#[derive(Clone)] pub struct CardanoTransactionsImporter { block_scanner: Arc, transaction_store: Arc, @@ -174,31 +176,29 @@ impl CardanoTransactionsImporter { .store_block_range_roots(block_ranges_with_merkle_root) .await } - - async fn import_transactions_and_block_ranges( - &self, - up_to_beacon: BlockNumber, - ) -> StdResult<()> { - self.import_transactions(up_to_beacon).await?; - self.import_block_ranges(up_to_beacon).await - } } #[async_trait] impl TransactionsImporter for CardanoTransactionsImporter { async fn import(&self, up_to_beacon: BlockNumber) -> StdResult<()> { - task::block_in_place(move || { + let importer = self.clone(); + task::spawn_blocking(move || { Handle::current().block_on(async move { - self.import_transactions_and_block_ranges(up_to_beacon) - .await + importer.import_transactions(up_to_beacon).await?; + importer.import_block_ranges(up_to_beacon).await?; + Ok(()) }) }) + .await + .with_context(|| "TransactionsImporter - worker thread crashed")? } } #[cfg(test)] mod tests { - use mithril_persistence::sqlite::SqliteConnectionPool; + use std::sync::atomic::AtomicUsize; + use std::time::Duration; + use mockall::mock; use mithril_common::cardano_block_scanner::{ @@ -207,6 +207,7 @@ mod tests { use mithril_common::crypto_helper::MKTree; use mithril_common::entities::{BlockNumber, BlockRangesSequence}; use mithril_persistence::database::repository::CardanoTransactionRepository; + use mithril_persistence::sqlite::SqliteConnectionPool; use crate::database::test_helper::cardano_tx_db_connection; use crate::test_tools::TestLogger; @@ -717,7 +718,7 @@ mod tests { ); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn importing_twice_starting_with_nothing_in_a_real_db_should_yield_transactions_in_same_order( ) { let blocks = vec![ @@ -754,7 +755,7 @@ mod tests { assert_eq!(cold_imported_transactions, warm_imported_transactions); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn when_rollbackward_should_remove_transactions() { let connection = cardano_tx_db_connection().unwrap(); let repository = Arc::new(CardanoTransactionRepository::new(Arc::new( @@ -797,7 +798,7 @@ mod tests { assert_eq!(expected_remaining_transactions, stored_transactions); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn when_rollbackward_should_remove_block_ranges() { let connection = cardano_tx_db_connection().unwrap(); let repository = Arc::new(CardanoTransactionRepository::new(Arc::new( @@ -869,4 +870,92 @@ mod tests { .collect::>() ); } + + #[tokio::test] + async fn test_import_is_non_blocking() { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + static MAX_COUNTER: usize = 25; + static WAIT_TIME: u64 = 50; + + // Use a local set to ensure the counter task is not dispatched on a different thread + let local = task::LocalSet::new(); + local + .run_until(async { + let importer = CardanoTransactionsImporter::new_for_test( + Arc::new(DumbBlockScanner::new()), + Arc::new(BlockingRepository { + wait_time: Duration::from_millis(WAIT_TIME), + }), + ); + + let importer_future = importer.import(100); + let counter_task = task::spawn_local(async { + while COUNTER.load(std::sync::atomic::Ordering::SeqCst) < MAX_COUNTER { + tokio::time::sleep(Duration::from_millis(1)).await; + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + }); + importer_future.await.unwrap(); + + counter_task.abort(); + }) + .await; + + assert_eq!( + MAX_COUNTER, + COUNTER.load(std::sync::atomic::Ordering::SeqCst) + ); + + struct BlockingRepository { + wait_time: Duration, + } + + impl BlockingRepository { + fn block_thread(&self) { + std::thread::sleep(self.wait_time); + } + } + + #[async_trait] + impl TransactionStore for BlockingRepository { + async fn get_highest_beacon(&self) -> StdResult> { + self.block_thread(); + Ok(None) + } + + async fn get_highest_block_range(&self) -> StdResult> { + self.block_thread(); + Ok(None) + } + + async fn store_transactions(&self, _: Vec) -> StdResult<()> { + self.block_thread(); + Ok(()) + } + + async fn get_transactions_in_range( + &self, + _: Range, + ) -> StdResult> { + self.block_thread(); + Ok(vec![]) + } + + async fn store_block_range_roots( + &self, + _: Vec<(BlockRange, MKTreeNode)>, + ) -> StdResult<()> { + self.block_thread(); + Ok(()) + } + + async fn remove_rolled_back_transactions_and_block_range( + &self, + _: SlotNumber, + ) -> StdResult<()> { + self.block_thread(); + Ok(()) + } + } + } } diff --git a/mithril-aggregator/tests/create_certificate.rs b/mithril-aggregator/tests/create_certificate.rs index 046a20a96aa..47a367efd8c 100644 --- a/mithril-aggregator/tests/create_certificate.rs +++ b/mithril-aggregator/tests/create_certificate.rs @@ -10,7 +10,7 @@ use mithril_common::{ }; use test_extensions::{utilities::get_test_dir, ExpectedCertificate, RuntimeTester}; -#[tokio::test(flavor = "multi_thread")] +#[tokio::test] async fn create_certificate() { let protocol_parameters = ProtocolParameters { k: 5, diff --git a/mithril-aggregator/tests/prove_transactions.rs b/mithril-aggregator/tests/prove_transactions.rs index 9e079f4099e..761710e350b 100644 --- a/mithril-aggregator/tests/prove_transactions.rs +++ b/mithril-aggregator/tests/prove_transactions.rs @@ -13,7 +13,7 @@ use crate::test_extensions::utilities::tx_hash; mod test_extensions; -#[tokio::test(flavor = "multi_thread")] +#[tokio::test] async fn prove_transactions() { let protocol_parameters = ProtocolParameters { k: 5, diff --git a/mithril-signer/Cargo.toml b/mithril-signer/Cargo.toml index 7deeedaeafe..f37a02f657b 100644 --- a/mithril-signer/Cargo.toml +++ b/mithril-signer/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-signer" -version = "0.2.163" +version = "0.2.164" description = "A Mithril Signer" authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-signer/src/cardano_transactions_importer.rs b/mithril-signer/src/cardano_transactions_importer.rs index 1e12654ced6..996e64d0b1f 100644 --- a/mithril-signer/src/cardano_transactions_importer.rs +++ b/mithril-signer/src/cardano_transactions_importer.rs @@ -3,6 +3,7 @@ use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; +use anyhow::Context; use async_trait::async_trait; use slog::{debug, Logger}; use tokio::{runtime::Handle, task}; @@ -51,6 +52,7 @@ pub trait TransactionStore: Send + Sync { } /// Import and store [CardanoTransaction]. +#[derive(Clone)] pub struct CardanoTransactionsImporter { block_scanner: Arc, transaction_store: Arc, @@ -174,31 +176,29 @@ impl CardanoTransactionsImporter { .store_block_range_roots(block_ranges_with_merkle_root) .await } - - async fn import_transactions_and_block_ranges( - &self, - up_to_beacon: BlockNumber, - ) -> StdResult<()> { - self.import_transactions(up_to_beacon).await?; - self.import_block_ranges(up_to_beacon).await - } } #[async_trait] impl TransactionsImporter for CardanoTransactionsImporter { async fn import(&self, up_to_beacon: BlockNumber) -> StdResult<()> { - task::block_in_place(move || { + let importer = self.clone(); + task::spawn_blocking(move || { Handle::current().block_on(async move { - self.import_transactions_and_block_ranges(up_to_beacon) - .await + importer.import_transactions(up_to_beacon).await?; + importer.import_block_ranges(up_to_beacon).await?; + Ok(()) }) }) + .await + .with_context(|| "TransactionsImporter - worker thread crashed")? } } #[cfg(test)] mod tests { - use mithril_persistence::sqlite::SqliteConnectionPool; + use std::sync::atomic::AtomicUsize; + use std::time::Duration; + use mockall::mock; use mithril_common::cardano_block_scanner::{ @@ -207,6 +207,7 @@ mod tests { use mithril_common::crypto_helper::MKTree; use mithril_common::entities::{BlockNumber, BlockRangesSequence}; use mithril_persistence::database::repository::CardanoTransactionRepository; + use mithril_persistence::sqlite::SqliteConnectionPool; use crate::database::test_helper::cardano_tx_db_connection; use crate::test_tools::TestLogger; @@ -717,7 +718,7 @@ mod tests { ); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn importing_twice_starting_with_nothing_in_a_real_db_should_yield_transactions_in_same_order( ) { let blocks = vec![ @@ -754,7 +755,7 @@ mod tests { assert_eq!(cold_imported_transactions, warm_imported_transactions); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn when_rollbackward_should_remove_transactions() { let connection = cardano_tx_db_connection().unwrap(); let repository = Arc::new(CardanoTransactionRepository::new(Arc::new( @@ -797,7 +798,7 @@ mod tests { assert_eq!(expected_remaining_transactions, stored_transactions); } - #[tokio::test(flavor = "multi_thread")] + #[tokio::test] async fn when_rollbackward_should_remove_block_ranges() { let connection = cardano_tx_db_connection().unwrap(); let repository = Arc::new(CardanoTransactionRepository::new(Arc::new( @@ -869,4 +870,92 @@ mod tests { .collect::>() ); } + + #[tokio::test] + async fn test_import_is_non_blocking() { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + static MAX_COUNTER: usize = 25; + static WAIT_TIME: u64 = 50; + + // Use a local set to ensure the counter task is not dispatched on a different thread + let local = task::LocalSet::new(); + local + .run_until(async { + let importer = CardanoTransactionsImporter::new_for_test( + Arc::new(DumbBlockScanner::new()), + Arc::new(BlockingRepository { + wait_time: Duration::from_millis(WAIT_TIME), + }), + ); + + let importer_future = importer.import(100); + let counter_task = task::spawn_local(async { + while COUNTER.load(std::sync::atomic::Ordering::SeqCst) < MAX_COUNTER { + tokio::time::sleep(Duration::from_millis(1)).await; + COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + } + }); + importer_future.await.unwrap(); + + counter_task.abort(); + }) + .await; + + assert_eq!( + MAX_COUNTER, + COUNTER.load(std::sync::atomic::Ordering::SeqCst) + ); + + struct BlockingRepository { + wait_time: Duration, + } + + impl BlockingRepository { + fn block_thread(&self) { + std::thread::sleep(self.wait_time); + } + } + + #[async_trait] + impl TransactionStore for BlockingRepository { + async fn get_highest_beacon(&self) -> StdResult> { + self.block_thread(); + Ok(None) + } + + async fn get_highest_block_range(&self) -> StdResult> { + self.block_thread(); + Ok(None) + } + + async fn store_transactions(&self, _: Vec) -> StdResult<()> { + self.block_thread(); + Ok(()) + } + + async fn get_transactions_in_range( + &self, + _: Range, + ) -> StdResult> { + self.block_thread(); + Ok(vec![]) + } + + async fn store_block_range_roots( + &self, + _: Vec<(BlockRange, MKTreeNode)>, + ) -> StdResult<()> { + self.block_thread(); + Ok(()) + } + + async fn remove_rolled_back_transactions_and_block_range( + &self, + _: SlotNumber, + ) -> StdResult<()> { + self.block_thread(); + Ok(()) + } + } + } } diff --git a/mithril-signer/tests/create_cardano_transaction_single_signature.rs b/mithril-signer/tests/create_cardano_transaction_single_signature.rs index 52cb5710b02..42b19245755 100644 --- a/mithril-signer/tests/create_cardano_transaction_single_signature.rs +++ b/mithril-signer/tests/create_cardano_transaction_single_signature.rs @@ -9,7 +9,7 @@ use mithril_common::{ use test_extensions::StateMachineTester; #[rustfmt::skip] -#[tokio::test(flavor = "multi_thread")] +#[tokio::test] async fn test_create_cardano_transaction_single_signature() { let protocol_parameters = tests_setup::setup_protocol_parameters(); let fixture = MithrilFixtureBuilder::default() diff --git a/mithril-test-lab/mithril-end-to-end/src/bin/load-aggregator/main.rs b/mithril-test-lab/mithril-end-to-end/src/bin/load-aggregator/main.rs index 0f472191c11..fc4082fe1b2 100644 --- a/mithril-test-lab/mithril-end-to-end/src/bin/load-aggregator/main.rs +++ b/mithril-test-lab/mithril-end-to-end/src/bin/load-aggregator/main.rs @@ -29,7 +29,7 @@ fn init_logger(opts: &MainOpts) -> slog_scope::GlobalLoggerGuard { slog_scope::set_global_logger(slog::Logger::root(Arc::new(drain), slog::o!())) } -#[tokio::main(flavor = "multi_thread")] +#[tokio::main] async fn main() -> StdResult<()> { let opts = MainOpts::parse(); let mut reporter: Reporter = Reporter::new(opts.num_signers, opts.num_clients);