Skip to content

Commit

Permalink
Use Tokio spawn_blocking instead of block_in_place in transcation…
Browse files Browse the repository at this point in the history
… importers

This avoid the need of specifying `(flavor = "multi_thread")` for the
tests.
  • Loading branch information
Alenar committed Jul 8, 2024
1 parent 70d4fe6 commit 0774128
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 34 deletions.
120 changes: 105 additions & 15 deletions mithril-aggregator/src/services/cardano_transactions_importer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::Context;
use async_trait::async_trait;
use slog::{debug, Logger};
use tokio::{runtime::Handle, task};
Expand Down Expand Up @@ -53,6 +54,7 @@ pub trait TransactionStore: Send + Sync {
}

/// Import and store [CardanoTransaction].
#[derive(Clone)]
pub struct CardanoTransactionsImporter {
block_scanner: Arc<dyn BlockScanner>,
transaction_store: Arc<dyn TransactionStore>,
Expand Down Expand Up @@ -175,31 +177,29 @@ impl CardanoTransactionsImporter {
.store_block_range_roots(block_ranges_with_merkle_root)
.await
}

async fn import_transactions_and_block_ranges(
&self,
up_to_beacon: BlockNumber,
) -> StdResult<()> {
self.import_transactions(up_to_beacon).await?;
self.import_block_ranges().await
}
}

#[async_trait]
impl TransactionsImporter for CardanoTransactionsImporter {
async fn import(&self, up_to_beacon: BlockNumber) -> StdResult<()> {
task::block_in_place(move || {
let importer = self.clone();
task::spawn_blocking(move || {
Handle::current().block_on(async move {
self.import_transactions_and_block_ranges(up_to_beacon)
.await
importer.import_transactions(up_to_beacon).await?;
importer.import_block_ranges().await?;
Ok(())
})
})
.await
.with_context(|| "TransactionsImporter - worker thread crashed")?
}
}

#[cfg(test)]
mod tests {
use mithril_persistence::sqlite::SqliteConnectionPool;
use std::sync::atomic::AtomicUsize;
use std::time::Duration;

use mockall::mock;

use mithril_common::cardano_block_scanner::{
Expand All @@ -208,6 +208,7 @@ mod tests {
use mithril_common::crypto_helper::MKTree;
use mithril_common::entities::{BlockNumber, BlockRangesSequence};
use mithril_persistence::database::repository::CardanoTransactionRepository;
use mithril_persistence::sqlite::SqliteConnectionPool;

use crate::database::test_helper::cardano_tx_db_connection;
use crate::test_tools::TestLogger;
Expand Down Expand Up @@ -652,7 +653,7 @@ mod tests {
);
}

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn importing_twice_starting_with_nothing_in_a_real_db_should_yield_transactions_in_same_order(
) {
let blocks = vec![
Expand Down Expand Up @@ -689,7 +690,7 @@ mod tests {
assert_eq!(cold_imported_transactions, warm_imported_transactions);
}

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn when_rollbackward_should_remove_transactions() {
let connection = cardano_tx_db_connection().unwrap();
let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(
Expand Down Expand Up @@ -732,7 +733,7 @@ mod tests {
assert_eq!(expected_remaining_transactions, stored_transactions);
}

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn when_rollbackward_should_remove_block_ranges() {
let connection = cardano_tx_db_connection().unwrap();
let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(
Expand Down Expand Up @@ -804,4 +805,93 @@ mod tests {
.collect::<Vec<_>>()
);
}

#[tokio::test]
async fn test_import_is_non_blocking() {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
static MAX_COUNTER: usize = 25;
static WAIT_TIME: u64 = 50;

// Use a local set to ensure the counter task is not dispatched on a different thread
let local = task::LocalSet::new();
local
.run_until(async {
let counter_task = task::spawn_local(async {
while COUNTER.load(std::sync::atomic::Ordering::SeqCst) < MAX_COUNTER {
tokio::time::sleep(Duration::from_millis(1)).await;
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
});

let importer = CardanoTransactionsImporter::new_for_test(
Arc::new(DumbBlockScanner::new()),
Arc::new(BlockingRepository {
wait_time: Duration::from_millis(WAIT_TIME),
}),
);

importer.import(100).await.unwrap();
counter_task.abort();
})
.await;

assert_eq!(
MAX_COUNTER,
COUNTER.load(std::sync::atomic::Ordering::SeqCst)
);

struct BlockingRepository {
wait_time: Duration,
}

impl BlockingRepository {
fn block_thread(&self) {
std::thread::sleep(self.wait_time);
}
}

#[async_trait]
impl TransactionStore for BlockingRepository {
async fn get_highest_beacon(&self) -> StdResult<Option<ChainPoint>> {
self.block_thread();
Ok(None)
}

async fn store_transactions(&self, _: Vec<CardanoTransaction>) -> StdResult<()> {
self.block_thread();
Ok(())
}

async fn get_block_interval_without_block_range_root(
&self,
) -> StdResult<Option<Range<BlockNumber>>> {
self.block_thread();
Ok(None)
}

async fn get_transactions_in_range(
&self,
_: Range<BlockNumber>,
) -> StdResult<Vec<CardanoTransaction>> {
self.block_thread();
Ok(vec![])
}

async fn store_block_range_roots(
&self,
_: Vec<(BlockRange, MKTreeNode)>,
) -> StdResult<()> {
self.block_thread();
Ok(())
}

async fn remove_rolled_back_transactions_and_block_range(
&self,
_: SlotNumber,
) -> StdResult<()> {
self.block_thread();
Ok(())
}
}
}
}
2 changes: 1 addition & 1 deletion mithril-aggregator/tests/create_certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use mithril_common::{
};
use test_extensions::{utilities::get_test_dir, ExpectedCertificate, RuntimeTester};

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn create_certificate() {
let protocol_parameters = ProtocolParameters {
k: 5,
Expand Down
2 changes: 1 addition & 1 deletion mithril-aggregator/tests/prove_transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::test_extensions::utilities::tx_hash;

mod test_extensions;

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn prove_transactions() {
let protocol_parameters = ProtocolParameters {
k: 5,
Expand Down
120 changes: 105 additions & 15 deletions mithril-signer/src/cardano_transactions_importer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use anyhow::Context;
use async_trait::async_trait;
use slog::{debug, Logger};
use tokio::{runtime::Handle, task};
Expand Down Expand Up @@ -53,6 +54,7 @@ pub trait TransactionStore: Send + Sync {
}

/// Import and store [CardanoTransaction].
#[derive(Clone)]
pub struct CardanoTransactionsImporter {
block_scanner: Arc<dyn BlockScanner>,
transaction_store: Arc<dyn TransactionStore>,
Expand Down Expand Up @@ -175,31 +177,29 @@ impl CardanoTransactionsImporter {
.store_block_range_roots(block_ranges_with_merkle_root)
.await
}

async fn import_transactions_and_block_ranges(
&self,
up_to_beacon: BlockNumber,
) -> StdResult<()> {
self.import_transactions(up_to_beacon).await?;
self.import_block_ranges().await
}
}

#[async_trait]
impl TransactionsImporter for CardanoTransactionsImporter {
async fn import(&self, up_to_beacon: BlockNumber) -> StdResult<()> {
task::block_in_place(move || {
let importer = self.clone();
task::spawn_blocking(move || {
Handle::current().block_on(async move {
self.import_transactions_and_block_ranges(up_to_beacon)
.await
importer.import_transactions(up_to_beacon).await?;
importer.import_block_ranges().await?;
Ok(())
})
})
.await
.with_context(|| "TransactionsImporter - worker thread crashed")?
}
}

#[cfg(test)]
mod tests {
use mithril_persistence::sqlite::SqliteConnectionPool;
use std::sync::atomic::AtomicUsize;
use std::time::Duration;

use mockall::mock;

use mithril_common::cardano_block_scanner::{
Expand All @@ -208,6 +208,7 @@ mod tests {
use mithril_common::crypto_helper::MKTree;
use mithril_common::entities::{BlockNumber, BlockRangesSequence};
use mithril_persistence::database::repository::CardanoTransactionRepository;
use mithril_persistence::sqlite::SqliteConnectionPool;

use crate::database::test_helper::cardano_tx_db_connection;
use crate::test_tools::TestLogger;
Expand Down Expand Up @@ -652,7 +653,7 @@ mod tests {
);
}

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn importing_twice_starting_with_nothing_in_a_real_db_should_yield_transactions_in_same_order(
) {
let blocks = vec![
Expand Down Expand Up @@ -689,7 +690,7 @@ mod tests {
assert_eq!(cold_imported_transactions, warm_imported_transactions);
}

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn when_rollbackward_should_remove_transactions() {
let connection = cardano_tx_db_connection().unwrap();
let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(
Expand Down Expand Up @@ -732,7 +733,7 @@ mod tests {
assert_eq!(expected_remaining_transactions, stored_transactions);
}

#[tokio::test(flavor = "multi_thread")]
#[tokio::test]
async fn when_rollbackward_should_remove_block_ranges() {
let connection = cardano_tx_db_connection().unwrap();
let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(
Expand Down Expand Up @@ -804,4 +805,93 @@ mod tests {
.collect::<Vec<_>>()
);
}

#[tokio::test]
async fn test_import_is_non_blocking() {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
static MAX_COUNTER: usize = 25;
static WAIT_TIME: u64 = 50;

// Use a local set to ensure the counter task is not dispatched on a different thread
let local = task::LocalSet::new();
local
.run_until(async {
let counter_task = task::spawn_local(async {
while COUNTER.load(std::sync::atomic::Ordering::SeqCst) < MAX_COUNTER {
tokio::time::sleep(Duration::from_millis(1)).await;
COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
});

let importer = CardanoTransactionsImporter::new_for_test(
Arc::new(DumbBlockScanner::new()),
Arc::new(BlockingRepository {
wait_time: Duration::from_millis(WAIT_TIME),
}),
);

importer.import(100).await.unwrap();
counter_task.abort();
})
.await;

assert_eq!(
MAX_COUNTER,
COUNTER.load(std::sync::atomic::Ordering::SeqCst)
);

struct BlockingRepository {
wait_time: Duration,
}

impl BlockingRepository {
fn block_thread(&self) {
std::thread::sleep(self.wait_time);
}
}

#[async_trait]
impl TransactionStore for BlockingRepository {
async fn get_highest_beacon(&self) -> StdResult<Option<ChainPoint>> {
self.block_thread();
Ok(None)
}

async fn store_transactions(&self, _: Vec<CardanoTransaction>) -> StdResult<()> {
self.block_thread();
Ok(())
}

async fn get_block_interval_without_block_range_root(
&self,
) -> StdResult<Option<Range<BlockNumber>>> {
self.block_thread();
Ok(None)
}

async fn get_transactions_in_range(
&self,
_: Range<BlockNumber>,
) -> StdResult<Vec<CardanoTransaction>> {
self.block_thread();
Ok(vec![])
}

async fn store_block_range_roots(
&self,
_: Vec<(BlockRange, MKTreeNode)>,
) -> StdResult<()> {
self.block_thread();
Ok(())
}

async fn remove_rolled_back_transactions_and_block_range(
&self,
_: SlotNumber,
) -> StdResult<()> {
self.block_thread();
Ok(())
}
}
}
}
Loading

0 comments on commit 0774128

Please sign in to comment.