diff --git a/tokio/src/sync/tests/loom_watch.rs b/tokio/src/sync/tests/loom_watch.rs index c575b5b66c5..51589cd8042 100644 --- a/tokio/src/sync/tests/loom_watch.rs +++ b/tokio/src/sync/tests/loom_watch.rs @@ -2,6 +2,7 @@ use crate::sync::watch; use loom::future::block_on; use loom::thread; +use std::sync::Arc; #[test] fn smoke() { @@ -34,3 +35,56 @@ fn smoke() { th.join().unwrap(); }) } + +#[test] +fn wait_for_test() { + loom::model(move || { + let (tx, mut rx) = watch::channel(false); + + let tx_arc = Arc::new(tx); + let tx1 = tx_arc.clone(); + let tx2 = tx_arc.clone(); + + let th1 = thread::spawn(move || { + for _ in 0..2 { + tx1.send_modify(|_x| {}); + } + }); + + let th2 = thread::spawn(move || { + tx2.send(true).unwrap(); + }); + + assert_eq!(*block_on(rx.wait_for(|x| *x)).unwrap(), true); + + th1.join().unwrap(); + th2.join().unwrap(); + }); +} + +#[test] +fn wait_for_returns_correct_value() { + loom::model(move || { + let (tx, mut rx) = watch::channel(0); + + let jh = thread::spawn(move || { + tx.send(1).unwrap(); + tx.send(2).unwrap(); + tx.send(3).unwrap(); + }); + + // Stop at the first value we are called at. + let mut stopped_at = usize::MAX; + let returned = *block_on(rx.wait_for(|x| { + stopped_at = *x; + true + })) + .unwrap(); + + // Check that it returned the same value as the one we returned + // `true` for. + assert_eq!(stopped_at, returned); + + jh.join().unwrap(); + }); +} diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 9ca75979b99..449711ad75c 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -595,18 +595,93 @@ impl Receiver { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { + changed_impl(&self.shared, &mut self.version).await + } + + /// Waits for a value that satisifes the provided condition. + /// + /// This method will call the provided closure whenever something is sent on + /// the channel. Once the closure returns `true`, this method will return a + /// reference to the value that was passed to the closure. + /// + /// Before `wait_for` starts waiting for changes, it will call the closure + /// on the current value. If the closure returns `true` when given the + /// current value, then `wait_for` will immediately return a reference to + /// the current value. This is the case even if the current value is already + /// considered seen. + /// + /// The watch channel only keeps track of the most recent value, so if + /// several messages are sent faster than `wait_for` is able to call the + /// closure, then it may skip some updates. Whenever the closure is called, + /// it will be called with the most recent value. + /// + /// When this function returns, the value that was passed to the closure + /// when it returned `true` will be considered seen. + /// + /// If the channel is closed, then `wait_for` will return a `RecvError`. + /// Once this happens, no more messages can ever be sent on the channel. + /// When an error is returned, it is guaranteed that the closure has been + /// called on the last value, and that it returned `false` for that value. + /// (If the closure returned `true`, then the last value would have been + /// returned instead of the error.) + /// + /// Like the `borrow` method, the returned borrow holds a read lock on the + /// inner value. This means that long-lived borrows could cause the producer + /// half to block. It is recommended to keep the borrow as short-lived as + /// possible. See the documentation of `borrow` for more information on + /// this. + /// + /// [`Receiver::changed()`]: crate::sync::watch::Receiver::changed + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// + /// async fn main() { + /// let (tx, _rx) = watch::channel("hello"); + /// + /// tx.send("goodbye").unwrap(); + /// + /// // here we subscribe to a second receiver + /// // now in case of using `changed` we would have + /// // to first check the current value and then wait + /// // for changes or else `changed` would hang. + /// let mut rx2 = tx.subscribe(); + /// + /// // in place of changed we have use `wait_for` + /// // which would automatically check the current value + /// // and wait for changes until the closure returns true. + /// assert!(rx2.wait_for(|val| *val == "goodbye").await.is_ok()); + /// assert_eq!(*rx2.borrow(), "goodbye"); + /// } + /// ``` + pub async fn wait_for( + &mut self, + mut f: impl FnMut(&T) -> bool, + ) -> Result, error::RecvError> { + let mut closed = false; loop { - // In order to avoid a race condition, we first request a notification, - // **then** check the current value's version. If a new version exists, - // the notification request is dropped. - let notified = self.shared.notify_rx.notified(); + { + let inner = self.shared.value.read().unwrap(); - if let Some(ret) = maybe_changed(&self.shared, &mut self.version) { - return ret; + let new_version = self.shared.state.load().version(); + let has_changed = self.version != new_version; + self.version = new_version; + + if (!closed || has_changed) && f(&inner) { + return Ok(Ref { inner, has_changed }); + } } - notified.await; - // loop around again in case the wake-up was spurious + if closed { + return Err(error::RecvError(())); + } + + // Wait for the value to change. + closed = changed_impl(&self.shared, &mut self.version).await.is_err(); } } @@ -655,6 +730,25 @@ fn maybe_changed( None } +async fn changed_impl( + shared: &Shared, + version: &mut Version, +) -> Result<(), error::RecvError> { + loop { + // In order to avoid a race condition, we first request a notification, + // **then** check the current value's version. If a new version exists, + // the notification request is dropped. + let notified = shared.notify_rx.notified(); + + if let Some(ret) = maybe_changed(shared, version) { + return ret; + } + + notified.await; + // loop around again in case the wake-up was spurious + } +} + impl Clone for Receiver { fn clone(&self) -> Self { let version = self.version;