diff --git a/Cargo.lock b/Cargo.lock index 8225e557141d1..110f6fcdc9a19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7227,6 +7227,7 @@ dependencies = [ "clap 4.0.11", "fdlimit", "futures", + "futures-timer", "libp2p", "log", "names", diff --git a/client/cli/Cargo.toml b/client/cli/Cargo.toml index 2f079a0c7c56f..fd84ff4d4574b 100644 --- a/client/cli/Cargo.toml +++ b/client/cli/Cargo.toml @@ -49,6 +49,7 @@ sp-version = { version = "5.0.0", path = "../../primitives/version" } [dev-dependencies] tempfile = "3.1.0" +futures-timer = "3.0.1" [features] default = ["rocksdb"] diff --git a/client/cli/src/runner.rs b/client/cli/src/runner.rs index c976c319708c2..d4191feddfa90 100644 --- a/client/cli/src/runner.rs +++ b/client/cli/src/runner.rs @@ -145,8 +145,14 @@ impl Runner { E: std::error::Error + Send + Sync + 'static + From, { self.print_node_infos(); + let mut task_manager = self.tokio_runtime.block_on(initialize(self.config))?; let res = self.tokio_runtime.block_on(main(task_manager.future().fuse())); + // We need to drop the task manager here to inform all tasks that they should shut down. + // + // This is important to be done before we instruct the tokio runtime to shutdown. Otherwise + // the tokio runtime will wait the full 60 seconds for all tasks to stop. + drop(task_manager); // Give all futures 60 seconds to shutdown, before tokio "leaks" them. self.tokio_runtime.shutdown_timeout(Duration::from_secs(60)); @@ -208,3 +214,208 @@ pub fn print_node_infos(config: &Configuration) { ); info!("⛓ Native runtime: {}", C::native_runtime_version(&config.chain_spec)); } + +#[cfg(test)] +mod tests { + use std::{ + path::PathBuf, + sync::atomic::{AtomicU64, Ordering}, + }; + + use sc_network::config::NetworkConfiguration; + use sc_service::{Arc, ChainType, GenericChainSpec, NoExtension}; + use sp_runtime::create_runtime_str; + use sp_version::create_apis_vec; + + use super::*; + + struct Cli; + + impl SubstrateCli for Cli { + fn author() -> String { + "test".into() + } + + fn impl_name() -> String { + "yep".into() + } + + fn impl_version() -> String { + "version".into() + } + + fn description() -> String { + "desc".into() + } + + fn support_url() -> String { + "no.pe".into() + } + + fn copyright_start_year() -> i32 { + 2042 + } + + fn load_spec( + &self, + _: &str, + ) -> std::result::Result, String> { + Err("nope".into()) + } + + fn native_runtime_version( + _: &Box, + ) -> &'static sp_version::RuntimeVersion { + const VERSION: sp_version::RuntimeVersion = sp_version::RuntimeVersion { + spec_name: create_runtime_str!("spec"), + impl_name: create_runtime_str!("name"), + authoring_version: 0, + spec_version: 0, + impl_version: 0, + apis: create_apis_vec!([]), + transaction_version: 2, + state_version: 0, + }; + + &VERSION + } + } + + fn create_runner() -> Runner { + let runtime = build_runtime().unwrap(); + + let runner = Runner::new( + Configuration { + impl_name: "spec".into(), + impl_version: "3".into(), + role: sc_service::Role::Authority, + tokio_handle: runtime.handle().clone(), + transaction_pool: Default::default(), + network: NetworkConfiguration::new_memory(), + keystore: sc_service::config::KeystoreConfig::InMemory, + keystore_remote: None, + database: sc_client_db::DatabaseSource::ParityDb { path: PathBuf::from("db") }, + trie_cache_maximum_size: None, + state_pruning: None, + blocks_pruning: sc_client_db::BlocksPruning::KeepAll, + chain_spec: Box::new(GenericChainSpec::from_genesis( + "test", + "test_id", + ChainType::Development, + || unimplemented!("Not required in tests"), + Vec::new(), + None, + None, + None, + None, + NoExtension::None, + )), + wasm_method: Default::default(), + wasm_runtime_overrides: None, + execution_strategies: Default::default(), + rpc_http: None, + rpc_ws: None, + rpc_ipc: None, + rpc_ws_max_connections: None, + rpc_cors: None, + rpc_methods: Default::default(), + rpc_max_payload: None, + rpc_max_request_size: None, + rpc_max_response_size: None, + rpc_id_provider: None, + rpc_max_subs_per_conn: None, + ws_max_out_buffer_capacity: None, + prometheus_config: None, + telemetry_endpoints: None, + default_heap_pages: None, + offchain_worker: Default::default(), + force_authoring: false, + disable_grandpa: false, + dev_key_seed: None, + tracing_targets: None, + tracing_receiver: Default::default(), + max_runtime_instances: 8, + announce_block: true, + base_path: None, + informant_output_format: Default::default(), + runtime_cache_size: 2, + }, + runtime, + ) + .unwrap(); + + runner + } + + #[test] + fn ensure_run_until_exit_informs_tasks_to_end() { + let runner = create_runner(); + + let counter = Arc::new(AtomicU64::new(0)); + let counter2 = counter.clone(); + + runner + .run_node_until_exit(move |cfg| async move { + let task_manager = TaskManager::new(cfg.tokio_handle.clone(), None).unwrap(); + let (sender, receiver) = futures::channel::oneshot::channel(); + + // We need to use `spawn_blocking` here so that we get a dedicated thread for our + // future. This is important for this test, as otherwise tokio can just "drop" the + // future. + task_manager.spawn_handle().spawn_blocking("test", None, async move { + let _ = sender.send(()); + loop { + counter2.fetch_add(1, Ordering::Relaxed); + futures_timer::Delay::new(Duration::from_millis(50)).await; + } + }); + + task_manager.spawn_essential_handle().spawn_blocking("test2", None, async { + // Let's stop this essential task directly when our other task started. + // It will signal that the task manager should end. + let _ = receiver.await; + }); + + Ok::<_, sc_service::Error>(task_manager) + }) + .unwrap_err(); + + let count = counter.load(Ordering::Relaxed); + + // Ensure that our counting task was running for less than 30 seconds. + // It should be directly killed, but for CI and whatever we are being a little bit more + // "relaxed". + assert!((count as u128) < (Duration::from_secs(30).as_millis() / 50)); + } + + /// This test ensures that `run_node_until_exit` aborts waiting for "stuck" tasks after 60 + /// seconds, aka doesn't wait until they are finished (which may never happen). + #[test] + fn ensure_run_until_exit_is_not_blocking_indefinitely() { + let runner = create_runner(); + + runner + .run_node_until_exit(move |cfg| async move { + let task_manager = TaskManager::new(cfg.tokio_handle.clone(), None).unwrap(); + let (sender, receiver) = futures::channel::oneshot::channel(); + + // We need to use `spawn_blocking` here so that we get a dedicated thread for our + // future. This future is more blocking code that will never end. + task_manager.spawn_handle().spawn_blocking("test", None, async move { + let _ = sender.send(()); + loop { + std::thread::sleep(Duration::from_secs(30)); + } + }); + + task_manager.spawn_essential_handle().spawn_blocking("test2", None, async { + // Let's stop this essential task directly when our other task started. + // It will signal that the task manager should end. + let _ = receiver.await; + }); + + Ok::<_, sc_service::Error>(task_manager) + }) + .unwrap_err(); + } +}