From 2ca8be17caf3a14f34f0daf128577637eed1e888 Mon Sep 17 00:00:00 2001 From: Naomi Plasterer Date: Wed, 21 Aug 2024 13:17:12 -0600 Subject: [PATCH] Sync all groups at once (#981) * add code to batch group syncs * small tweeks to sync all * write a test for it * fix up the lint * rename * cargo fmt * make it faster * add an involved test * cargo fmt --- bindings_ffi/src/mls.rs | 59 +++++++++++++++++++++++ xmtp_mls/src/client.rs | 94 +++++++++++++++++++++++++++++++++++++ xmtp_mls/src/groups/mod.rs | 5 +- xmtp_mls/src/groups/sync.rs | 4 +- 4 files changed, 159 insertions(+), 3 deletions(-) diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index ec827873e..4deb0fc8d 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -752,6 +752,14 @@ impl FfiConversations { Ok(()) } + pub async fn sync_all_groups(&self) -> Result<(), GenericError> { + let inner = self.inner_client.as_ref(); + let groups = inner.find_groups(None, None, None, None)?; + + inner.sync_all_groups(groups).await?; + Ok(()) + } + pub async fn list( &self, opts: FfiListConversationsOptions, @@ -2236,6 +2244,57 @@ mod tests { assert!(stream_messages.is_closed()); } + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] + async fn test_can_sync_all_groups() { + let alix = new_test_client().await; + let bo = new_test_client().await; + + for _i in 0..30 { + alix.conversations() + .create_group( + vec![bo.account_address.clone()], + FfiCreateGroupOptions::default(), + ) + .await + .unwrap(); + } + + bo.conversations().sync().await.unwrap(); + let alix_groups = alix + .conversations() + .list(FfiListConversationsOptions::default()) + .await + .unwrap(); + + let alix_group1 = alix_groups[0].clone(); + let alix_group5 = alix_groups[5].clone(); + let bo_group1 = bo.group(alix_group1.id()).unwrap(); + let bo_group5 = bo.group(alix_group5.id()).unwrap(); + + alix_group1.send("alix1".as_bytes().to_vec()).await.unwrap(); + alix_group5.send("alix1".as_bytes().to_vec()).await.unwrap(); + + let bo_messages1 = bo_group1 + .find_messages(FfiListMessagesOptions::default()) + .unwrap(); + let bo_messages5 = bo_group5 + .find_messages(FfiListMessagesOptions::default()) + .unwrap(); + assert_eq!(bo_messages1.len(), 0); + assert_eq!(bo_messages5.len(), 0); + + bo.conversations().sync_all_groups().await.unwrap(); + + let bo_messages1 = bo_group1 + .find_messages(FfiListMessagesOptions::default()) + .unwrap(); + let bo_messages5 = bo_group5 + .find_messages(FfiListMessagesOptions::default()) + .unwrap(); + assert_eq!(bo_messages1.len(), 1); + assert_eq!(bo_messages5.len(), 1); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] async fn test_can_send_message_when_out_of_sync() { let alix = new_test_client().await; diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 7beb42e80..225d33c2f 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, mem::Discriminant, sync::Arc}; use futures::{ + future::join_all, stream::{self, StreamExt}, Future, }; @@ -587,6 +588,43 @@ where Ok(groups) } + pub async fn sync_all_groups(&self, groups: Vec) -> Result<(), GroupError> { + // Acquire a single connection to be reused + let conn = &self.store().conn()?; + + let sync_futures: Vec<_> = groups + .into_iter() + .map(|group| { + let conn = conn.clone(); + let mls_provider = self.mls_provider(conn.clone()); + + async move { + log::info!("[{}] syncing group", self.inbox_id()); + log::info!( + "current epoch for [{}] in sync_all_groups() is Epoch: [{}]", + self.inbox_id(), + group.load_mls_group(mls_provider.clone()).unwrap().epoch() + ); + + group + .maybe_update_installations(conn.clone(), None, self) + .await?; + + group.sync_with_conn(conn.clone(), self).await?; + Ok::<(), GroupError>(()) + } + }) + .collect(); + + // Run all sync operations concurrently + join_all(sync_futures) + .await + .into_iter() + .collect::>()?; + + Ok(()) + } + /** * Validates a credential against the given installation public key * @@ -776,6 +814,62 @@ mod tests { assert_eq!(duplicate_received_groups.len(), 0); } + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_sync_all_groups() { + let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await; + let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await; + + let alix_bo_group1 = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + let alix_bo_group2 = alix + .create_group(None, GroupMetadataOptions::default()) + .unwrap(); + alix_bo_group1 + .add_members_by_inbox_id(&alix, vec![bo.inbox_id()]) + .await + .unwrap(); + alix_bo_group2 + .add_members_by_inbox_id(&alix, vec![bo.inbox_id()]) + .await + .unwrap(); + + let bob_received_groups = bo.sync_welcomes().await.unwrap(); + assert_eq!(bob_received_groups.len(), 2); + + let bo_groups = bo.find_groups(None, None, None, None).unwrap(); + let bo_group1 = bo.group(alix_bo_group1.clone().group_id).unwrap(); + let bo_messages1 = bo_group1 + .find_messages(None, None, None, None, None) + .unwrap(); + assert_eq!(bo_messages1.len(), 0); + let bo_group2 = bo.group(alix_bo_group2.clone().group_id).unwrap(); + let bo_messages2 = bo_group2 + .find_messages(None, None, None, None, None) + .unwrap(); + assert_eq!(bo_messages2.len(), 0); + alix_bo_group1 + .send_message(vec![1, 2, 3].as_slice(), &alix) + .await + .unwrap(); + alix_bo_group2 + .send_message(vec![1, 2, 3].as_slice(), &alix) + .await + .unwrap(); + + bo.sync_all_groups(bo_groups).await.unwrap(); + + let bo_messages1 = bo_group1 + .find_messages(None, None, None, None, None) + .unwrap(); + assert_eq!(bo_messages1.len(), 1); + let bo_group2 = bo.group(alix_bo_group2.clone().group_id).unwrap(); + let bo_messages2 = bo_group2 + .find_messages(None, None, None, None, None) + .unwrap(); + assert_eq!(bo_messages2.len(), 1); + } + #[tokio::test] async fn test_can_message() { // let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index b842aa2ca..1645769b2 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -253,7 +253,10 @@ impl MlsGroup { // Load the stored MLS group from the OpenMLS provider's keystore #[tracing::instrument(level = "trace", skip_all)] - fn load_mls_group(&self, provider: impl OpenMlsProvider) -> Result { + pub fn load_mls_group( + &self, + provider: impl OpenMlsProvider, + ) -> Result { let mls_group = OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id)) .map_err(|_| GroupError::GroupNotFound)? diff --git a/xmtp_mls/src/groups/sync.rs b/xmtp_mls/src/groups/sync.rs index cf05295d0..0723de47c 100644 --- a/xmtp_mls/src/groups/sync.rs +++ b/xmtp_mls/src/groups/sync.rs @@ -98,7 +98,7 @@ impl MlsGroup { } #[tracing::instrument(level = "trace", skip(client, self, conn))] - pub(super) async fn sync_with_conn( + pub async fn sync_with_conn( &self, conn: DbConnection, client: &Client, @@ -1000,7 +1000,7 @@ impl MlsGroup { Ok(()) } - pub(super) async fn maybe_update_installations( + pub async fn maybe_update_installations( &self, conn: DbConnection, update_interval: Option,