From 7fa06c125e84f8f58b8b94a54739ed39d6978a1a Mon Sep 17 00:00:00 2001 From: Marc-Antoine Perennou Date: Sat, 18 Jul 2020 20:32:56 +0200 Subject: [PATCH] switch to multitask Signed-off-by: Marc-Antoine Perennou --- Cargo.toml | 16 +++++--- src/task/builder.rs | 14 +++---- src/task/executor.rs | 91 +++++++++++++++++++++++++++++++++++++++++ src/task/join_handle.rs | 21 ++++++---- src/task/mod.rs | 2 + 5 files changed, 123 insertions(+), 21 deletions(-) create mode 100644 src/task/executor.rs diff --git a/Cargo.toml b/Cargo.toml index 70ab8f6ad..c9944872d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,9 +29,9 @@ default = [ "blocking", "kv-log-macro", "log", + "multitask", "num_cpus", "pin-project-lite", - "smol", ] docs = ["attributes", "unstable", "default"] unstable = [ @@ -56,7 +56,7 @@ alloc = [ "futures-core/alloc", "pin-project-lite", ] -tokio02 = ["smol/tokio02"] +tokio02 = ["tokio"] [dependencies] async-attributes = { version = "1.1.1", optional = true } @@ -79,9 +79,9 @@ futures-timer = { version = "3.0.2", optional = true } surf = { version = "1.0.3", optional = true } [target.'cfg(not(target_os = "unknown"))'.dependencies] -async-io = { version = "0.1.2", optional = true } -blocking = { version = "0.4.6", optional = true } -smol = { version = "0.1.17", optional = true } +async-io = { version = "0.1.4", optional = true } +blocking = { version = "0.4.7", optional = true } +multitask = { version = "0.2.0", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] futures-timer = { version = "3.0.2", optional = true, features = ["wasm-bindgen"] } @@ -91,6 +91,12 @@ futures-channel = { version = "0.3.4", optional = true } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] wasm-bindgen-test = "0.3.10" +[dependencies.tokio] +version = "0.2" +default-features = false +features = ["rt-threaded"] +optional = true + [dev-dependencies] femme = "1.3.0" rand = "0.7.3" diff --git a/src/task/builder.rs b/src/task/builder.rs index ab43d5587..0229e4140 100644 --- a/src/task/builder.rs +++ b/src/task/builder.rs @@ -7,7 +7,7 @@ use std::task::{Context, Poll}; use pin_project_lite::pin_project; use crate::io; -use crate::task::{JoinHandle, Task, TaskLocalsWrapper}; +use crate::task::{self, JoinHandle, Task, TaskLocalsWrapper}; /// Task builder that configures the settings of a new task. #[derive(Debug, Default)] @@ -61,9 +61,9 @@ impl Builder { }); let task = wrapped.tag.task().clone(); - let smol_task = smol::Task::spawn(wrapped).into(); + let handle = task::executor::spawn(wrapped); - Ok(JoinHandle::new(smol_task, task)) + Ok(JoinHandle::new(handle, task)) } /// Spawns a task locally with the configured settings. @@ -81,9 +81,9 @@ impl Builder { }); let task = wrapped.tag.task().clone(); - let smol_task = smol::Task::local(wrapped).into(); + let handle = task::executor::local(wrapped); - Ok(JoinHandle::new(smol_task, task)) + Ok(JoinHandle::new(handle, task)) } /// Spawns a task locally with the configured settings. @@ -166,8 +166,8 @@ impl Builder { unsafe { TaskLocalsWrapper::set_current(&wrapped.tag, || { let res = if should_run { - // The first call should use run. - smol::run(wrapped) + // The first call should run the executor + task::executor::run(wrapped) } else { blocking::block_on(wrapped) }; diff --git a/src/task/executor.rs b/src/task/executor.rs new file mode 100644 index 000000000..02fa4ca7e --- /dev/null +++ b/src/task/executor.rs @@ -0,0 +1,91 @@ +use std::cell::RefCell; +use std::future::Future; +use std::task::{Context, Poll}; + +static GLOBAL_EXECUTOR: once_cell::sync::Lazy = once_cell::sync::Lazy::new(multitask::Executor::new); + +struct Executor { + local_executor: multitask::LocalExecutor, + parker: async_io::parking::Parker, +} + +thread_local! { + static EXECUTOR: RefCell = RefCell::new({ + let (parker, unparker) = async_io::parking::pair(); + let local_executor = multitask::LocalExecutor::new(move || unparker.unpark()); + Executor { local_executor, parker } + }); +} + +pub(crate) fn spawn(future: F) -> multitask::Task +where + F: Future + Send + 'static, + T: Send + 'static, +{ + GLOBAL_EXECUTOR.spawn(future) +} + +#[cfg(feature = "unstable")] +pub(crate) fn local(future: F) -> multitask::Task +where + F: Future + 'static, + T: 'static, +{ + EXECUTOR.with(|executor| executor.borrow().local_executor.spawn(future)) +} + +pub(crate) fn run(future: F) -> T +where + F: Future, +{ + enter(|| EXECUTOR.with(|executor| { + let executor = executor.borrow(); + let unparker = executor.parker.unparker(); + let global_ticker = GLOBAL_EXECUTOR.ticker(move || unparker.unpark()); + let unparker = executor.parker.unparker(); + let waker = async_task::waker_fn(move || unparker.unpark()); + let cx = &mut Context::from_waker(&waker); + pin_utils::pin_mut!(future); + loop { + if let Poll::Ready(res) = future.as_mut().poll(cx) { + return res; + } + if let Ok(false) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| executor.local_executor.tick() || global_ticker.tick())) { + executor.parker.park(); + } + } + })) +} + +/// Enters the tokio context if the `tokio` feature is enabled. +fn enter(f: impl FnOnce() -> T) -> T { + #[cfg(not(feature = "tokio02"))] + return f(); + + #[cfg(feature = "tokio02")] + { + use std::cell::Cell; + use tokio::runtime::Runtime; + + thread_local! { + /// The level of nested `enter` calls we are in, to ensure that the outermost always + /// has a runtime spawned. + static NESTING: Cell = Cell::new(0); + } + + /// The global tokio runtime. + static RT: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| Runtime::new().expect("cannot initialize tokio")); + + NESTING.with(|nesting| { + let res = if nesting.get() == 0 { + nesting.replace(1); + RT.enter(f) + } else { + nesting.replace(nesting.get() + 1); + f() + }; + nesting.replace(nesting.get() - 1); + res + }) + } +} diff --git a/src/task/join_handle.rs b/src/task/join_handle.rs index 110b827e2..fd0d0fb77 100644 --- a/src/task/join_handle.rs +++ b/src/task/join_handle.rs @@ -18,7 +18,7 @@ pub struct JoinHandle { } #[cfg(not(target_os = "unknown"))] -type InnerHandle = async_task::JoinHandle; +type InnerHandle = multitask::Task; #[cfg(target_arch = "wasm32")] type InnerHandle = futures_channel::oneshot::Receiver; @@ -54,8 +54,7 @@ impl JoinHandle { #[cfg(not(target_os = "unknown"))] pub async fn cancel(mut self) -> Option { let handle = self.handle.take().unwrap(); - handle.cancel(); - handle.await + handle.cancel().await } /// Cancel this task. @@ -67,15 +66,19 @@ impl JoinHandle { } } +#[cfg(not(target_os = "unknown"))] +impl Drop for JoinHandle { + fn drop(&mut self) { + if let Some(handle) = self.handle.take() { + handle.detach(); + } + } +} + impl Future for JoinHandle { type Output = T; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(output) => { - Poll::Ready(output.expect("cannot await the result of a panicked task")) - } - } + Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx) } } diff --git a/src/task/mod.rs b/src/task/mod.rs index ca0b92a02..9e025baf4 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -148,6 +148,8 @@ cfg_default! { mod block_on; mod builder; mod current; + #[cfg(not(target_os = "unknown"))] + mod executor; mod join_handle; mod sleep; #[cfg(not(target_os = "unknown"))]