Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 63a586d
Author: sinu.eth <65924192+sinui0@users.noreply.github.com>
Date:   Mon Jun 10 23:09:01 2024 -0700

    feat(mpz-ot): impl more OT traits on shared KOS (#153)

commit 5a20d29
Author: sinu.eth <65924192+sinui0@users.noreply.github.com>
Date:   Mon Jun 10 23:07:17 2024 -0700

    feat(mpz-common): async sync primitives (#152)

    * feat(mpz-common): async sync primitives

    * update syncer test

    * add unsync lock method

    * update async syncer test

commit 86eebbf
Author: sinu.eth <65924192+sinui0@users.noreply.github.com>
Date:   Fri Jun 7 09:28:05 2024 -0700

    feat(mpz-common): add type alias for test st executor (#154)
  • Loading branch information
sinui0 committed Jun 11, 2024
1 parent b28f69d commit 0d39556
Show file tree
Hide file tree
Showing 9 changed files with 494 additions and 47 deletions.
1 change: 1 addition & 0 deletions crates/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ tokio = "1.23"
tokio-util = "0.7"
scoped-futures = "0.1.3"
pollster = "0.3"
pin-project-lite = "0.2"

# serialization
ark-serialize = "0.4"
Expand Down
4 changes: 3 additions & 1 deletion crates/mpz-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"

[features]
default = ["sync"]
sync = []
sync = ["tokio/sync"]
test-utils = ["uid-mux/test-utils"]
ideal = []
rayon = ["dep:rayon"]
Expand All @@ -16,6 +16,7 @@ mpz-core.workspace = true

futures.workspace = true
async-trait.workspace = true
pin-project-lite.workspace = true
scoped-futures.workspace = true
thiserror.workspace = true
serio.workspace = true
Expand All @@ -24,6 +25,7 @@ serde = { workspace = true, features = ["derive"] }
pollster.workspace = true
rayon = { workspace = true, optional = true }
cfg-if.workspace = true
tokio = { workspace = true, optional = true, default-features = false }

[dev-dependencies]
tokio = { workspace = true, features = [
Expand Down
7 changes: 4 additions & 3 deletions crates/mpz-common/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ mod test_utils {

use super::*;

/// Test single-threaded executor.
pub type TestSTExecutor = STExecutor<MemoryDuplex>;

/// Creates a pair of single-threaded executors with memory I/O channels.
pub fn test_st_executor(
io_buffer: usize,
) -> (STExecutor<MemoryDuplex>, STExecutor<MemoryDuplex>) {
pub fn test_st_executor(io_buffer: usize) -> (TestSTExecutor, TestSTExecutor) {
let (io_0, io_1) = duplex(io_buffer);

(STExecutor::new(io_0), STExecutor::new(io_1))
Expand Down
109 changes: 109 additions & 0 deletions crates/mpz-common/src/sync/async_mutex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//! Synchronized async mutex.

use pollster::FutureExt;
use tokio::sync::{Mutex as TokioMutex, MutexGuard};

use crate::{
context::Context,
sync::{AsyncSyncer, MutexError},
};

/// A mutex which synchronizes exclusive access to a resource across logical threads.
///
/// There are two configurations for a mutex, either as a leader or as a follower.
///
/// **Leader**
///
/// A leader mutex is the authority on the order in which threads can acquire a lock. When a
/// thread acquires a lock, it broadcasts a message to all follower mutexes, which then enforce
/// that this order is preserved.
///
/// **Follower**
///
/// A follower mutex waits for messages from the leader mutex to inform it of the order in which
/// threads can acquire a lock.
#[derive(Debug)]
pub struct AsyncMutex<T> {
inner: TokioMutex<T>,
syncer: AsyncSyncer,
}

impl<T> AsyncMutex<T> {
/// Creates a new leader mutex.
///
/// # Arguments
///
/// * `value` - The value protected by the mutex.
pub fn new_leader(value: T) -> Self {
Self {
inner: TokioMutex::new(value),
syncer: AsyncSyncer::new_leader(),
}
}

/// Creates a new follower mutex.
///
/// # Arguments
///
/// * `value` - The value protected by the mutex.
pub fn new_follower(value: T) -> Self {
Self {
inner: TokioMutex::new(value),
syncer: AsyncSyncer::new_follower(),
}
}

/// Returns a lock on the mutex.
pub async fn lock<Ctx: Context>(&self, ctx: &mut Ctx) -> Result<MutexGuard<'_, T>, MutexError> {
self.syncer
.sync(ctx.io_mut(), self.inner.lock())
.await
.map_err(MutexError::from)
}

/// Returns an unsynchronized blocking lock on the mutex.
///
/// # Warning
///
/// Do not use this method unless you are certain that the way you're mutating the state does
/// not require synchronization. Also, don't hold this lock across await points it will cause
/// deadlocks.
pub fn blocking_lock_unsync(&self) -> MutexGuard<'_, T> {
self.inner.lock().block_on()
}

/// Returns the inner value, consuming the mutex.
pub fn into_inner(self) -> T {
self.inner.into_inner()
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;

#[test]
fn test_async_mutex() {
let leader_mutex = Arc::new(AsyncMutex::new_leader(()));
let follower_mutex = Arc::new(AsyncMutex::new_follower(()));

let (mut ctx_a, mut ctx_b) = crate::executor::test_st_executor(8);

futures::executor::block_on(async {
futures::join!(
async {
drop(leader_mutex.lock(&mut ctx_a).await.unwrap());
drop(leader_mutex.lock(&mut ctx_a).await.unwrap());
drop(leader_mutex.lock(&mut ctx_a).await.unwrap());
},
async {
drop(follower_mutex.lock(&mut ctx_b).await.unwrap());
drop(follower_mutex.lock(&mut ctx_b).await.unwrap());
drop(follower_mutex.lock(&mut ctx_b).await.unwrap());
},
);
});
}
}
239 changes: 239 additions & 0 deletions crates/mpz-common/src/sync/async_syncer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
use std::{
collections::HashMap,
pin::Pin,
sync::{Arc, Mutex as StdMutex},
task::{ready, Context as StdContext, Poll, Waker},
};

use futures::{future::poll_fn, Future};
use serio::{stream::IoStreamExt, IoDuplex};
use tokio::sync::Mutex;

use crate::sync::{SyncError, Ticket};

/// An async version of [`Syncer`](crate::sync::Syncer).
#[derive(Debug, Clone)]
pub struct AsyncSyncer(SyncerInner);

impl AsyncSyncer {
/// Creates a new leader.
pub fn new_leader() -> Self {
Self(SyncerInner::Leader(Leader::default()))
}

/// Creates a new follower.
pub fn new_follower() -> Self {
Self(SyncerInner::Follower(Follower::default()))
}

/// Synchronizes the order of execution across logical threads.
///
/// # Arguments
///
/// * `io` - The I/O channel of the logical thread.
/// * `fut` - The future to await.
pub async fn sync<Io: IoDuplex + Unpin, Fut>(
&self,
io: &mut Io,
fut: Fut,
) -> Result<Fut::Output, SyncError>
where
Fut: Future,
{
match &self.0 {
SyncerInner::Leader(leader) => leader.sync(io, fut).await,
SyncerInner::Follower(follower) => follower.sync(io, fut).await,
}
}
}

#[derive(Debug, Clone)]
enum SyncerInner {
Leader(Leader),
Follower(Follower),
}

#[derive(Debug, Default, Clone)]
struct Leader {
tick: Arc<Mutex<Ticket>>,
}

impl Leader {
async fn sync<Io: IoDuplex + Unpin, Fut>(
&self,
io: &mut Io,
fut: Fut,
) -> Result<Fut::Output, SyncError>
where
Fut: Future,
{
let mut io = Pin::new(io);
poll_fn(|cx| io.as_mut().poll_ready(cx)).await?;
let (output, tick) = {
let mut tick_lock = self.tick.lock().await;
let output = fut.await;
let tick = tick_lock.increment_in_place();
(output, tick)
};
io.start_send(tick)?;
Ok(output)
}
}

#[derive(Debug, Default, Clone)]
struct Follower {
queue: Arc<StdMutex<Queue>>,
}

impl Follower {
async fn sync<Io: IoDuplex + Unpin, Fut>(
&self,
io: &mut Io,
fut: Fut,
) -> Result<Fut::Output, SyncError>
where
Fut: Future,
{
let tick = io.expect_next().await?;
Ok(Wait::new(&self.queue, tick, fut).await)
}
}

#[derive(Debug, Default)]
struct Queue {
// The current ticket.
tick: Ticket,
// Tasks waiting for their ticket to be accepted.
waiting: HashMap<Ticket, Waker>,
}

impl Queue {
// Wakes up the next waiting task.
fn wake_next(&mut self) {
if let Some(waker) = self.waiting.remove(&self.tick) {
waker.wake();
}
}
}

pin_project_lite::pin_project! {
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Wait<'a, Fut> {
queue: &'a StdMutex<Queue>,
tick: Ticket,
#[pin]
fut: Fut,
}
}

impl<'a, Fut> Wait<'a, Fut> {
fn new(queue: &'a StdMutex<Queue>, tick: Ticket, fut: Fut) -> Self {
Self { queue, tick, fut }
}
}

impl<'a, Fut> Future for Wait<'a, Fut>
where
Fut: Future,
{
type Output = Fut::Output;

fn poll(self: Pin<&mut Self>, cx: &mut StdContext<'_>) -> Poll<Self::Output> {
let mut queue_lock = self.queue.lock().unwrap();
if queue_lock.tick == self.tick {
let this = self.project();
let output = ready!(this.fut.poll(cx));
queue_lock.tick.increment_in_place();
queue_lock.wake_next();
Poll::Ready(output)
} else {
queue_lock.waiting.insert(self.tick, cx.waker().clone());
Poll::Pending
}
}
}

#[cfg(test)]
mod tests {
use futures::{executor::block_on, poll};
use serio::channel::duplex;

use super::*;

#[test]
fn test_syncer() {
let (mut io_0a, mut io_0b) = duplex(1);
let (mut io_1a, mut io_1b) = duplex(1);
let (mut io_2a, mut io_2b) = duplex(1);

let syncer_a = AsyncSyncer::new_leader();
let syncer_b = AsyncSyncer::new_follower();

let log_a = Arc::new(Mutex::new(Vec::new()));
let log_b = Arc::new(Mutex::new(Vec::new()));

block_on(async {
syncer_a
.sync(&mut io_0a, async {
let mut log = log_a.lock().await;
log.push(0);
})
.await
.unwrap();
syncer_a
.sync(&mut io_1a, async {
let mut log = log_a.lock().await;
log.push(1);
})
.await
.unwrap();
syncer_a
.sync(&mut io_2a, async {
let mut log = log_a.lock().await;
log.push(2);
})
.await
.unwrap();
});

let mut fut_a = Box::pin(syncer_b.sync(&mut io_2b, async {
let mut log = log_b.lock().await;
log.push(2);
}));

let mut fut_b = Box::pin(syncer_b.sync(&mut io_0b, async {
let mut log = log_b.lock().await;
log.push(0);
}));

let mut fut_c = Box::pin(syncer_b.sync(&mut io_1b, async {
let mut log = log_b.lock().await;
log.push(1);
}));

block_on(async move {
// Poll out of order.
assert!(poll!(&mut fut_a).is_pending());
assert!(poll!(&mut fut_c).is_pending());
assert!(poll!(&mut fut_b).is_ready());
assert!(poll!(&mut fut_c).is_ready());
assert!(poll!(&mut fut_a).is_ready());
});

let log_a = Arc::into_inner(log_a).unwrap().into_inner();
let log_b = Arc::into_inner(log_b).unwrap().into_inner();

assert_eq!(log_a, log_b);
}

#[test]
fn test_syncer_is_send() {
let (mut io, _) = duplex(1);
let syncer = AsyncSyncer::new_leader();

fn is_send<T: Send>(_: T) {}

is_send(syncer.sync(&mut io, async {}));
}
}
Loading

0 comments on commit 0d39556

Please sign in to comment.