Skip to content

Commit

Permalink
Sync all groups at once (#981)
Browse files Browse the repository at this point in the history
* add code to batch group syncs

* small tweeks to sync all

* write a test for it

* fix up the lint

* rename

* cargo fmt

* make it faster

* add an involved test

* cargo fmt
  • Loading branch information
nplasterer authored Aug 21, 2024
1 parent 87b83fe commit 2ca8be1
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 3 deletions.
59 changes: 59 additions & 0 deletions bindings_ffi/src/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,14 @@ impl FfiConversations {
Ok(())
}

pub async fn sync_all_groups(&self) -> Result<(), GenericError> {
let inner = self.inner_client.as_ref();
let groups = inner.find_groups(None, None, None, None)?;

inner.sync_all_groups(groups).await?;
Ok(())
}

pub async fn list(
&self,
opts: FfiListConversationsOptions,
Expand Down Expand Up @@ -2236,6 +2244,57 @@ mod tests {
assert!(stream_messages.is_closed());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_can_sync_all_groups() {
let alix = new_test_client().await;
let bo = new_test_client().await;

for _i in 0..30 {
alix.conversations()
.create_group(
vec![bo.account_address.clone()],
FfiCreateGroupOptions::default(),
)
.await
.unwrap();
}

bo.conversations().sync().await.unwrap();
let alix_groups = alix
.conversations()
.list(FfiListConversationsOptions::default())
.await
.unwrap();

let alix_group1 = alix_groups[0].clone();
let alix_group5 = alix_groups[5].clone();
let bo_group1 = bo.group(alix_group1.id()).unwrap();
let bo_group5 = bo.group(alix_group5.id()).unwrap();

alix_group1.send("alix1".as_bytes().to_vec()).await.unwrap();
alix_group5.send("alix1".as_bytes().to_vec()).await.unwrap();

let bo_messages1 = bo_group1
.find_messages(FfiListMessagesOptions::default())
.unwrap();
let bo_messages5 = bo_group5
.find_messages(FfiListMessagesOptions::default())
.unwrap();
assert_eq!(bo_messages1.len(), 0);
assert_eq!(bo_messages5.len(), 0);

bo.conversations().sync_all_groups().await.unwrap();

let bo_messages1 = bo_group1
.find_messages(FfiListMessagesOptions::default())
.unwrap();
let bo_messages5 = bo_group5
.find_messages(FfiListMessagesOptions::default())
.unwrap();
assert_eq!(bo_messages1.len(), 1);
assert_eq!(bo_messages5.len(), 1);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn test_can_send_message_when_out_of_sync() {
let alix = new_test_client().await;
Expand Down
94 changes: 94 additions & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{collections::HashMap, mem::Discriminant, sync::Arc};

use futures::{
future::join_all,
stream::{self, StreamExt},
Future,
};
Expand Down Expand Up @@ -587,6 +588,43 @@ where
Ok(groups)
}

pub async fn sync_all_groups(&self, groups: Vec<MlsGroup>) -> Result<(), GroupError> {
// Acquire a single connection to be reused
let conn = &self.store().conn()?;

let sync_futures: Vec<_> = groups
.into_iter()
.map(|group| {
let conn = conn.clone();
let mls_provider = self.mls_provider(conn.clone());

async move {
log::info!("[{}] syncing group", self.inbox_id());
log::info!(
"current epoch for [{}] in sync_all_groups() is Epoch: [{}]",
self.inbox_id(),
group.load_mls_group(mls_provider.clone()).unwrap().epoch()
);

group
.maybe_update_installations(conn.clone(), None, self)
.await?;

group.sync_with_conn(conn.clone(), self).await?;
Ok::<(), GroupError>(())
}
})
.collect();

// Run all sync operations concurrently
join_all(sync_futures)
.await
.into_iter()
.collect::<Result<(), _>>()?;

Ok(())
}

/**
* Validates a credential against the given installation public key
*
Expand Down Expand Up @@ -776,6 +814,62 @@ mod tests {
assert_eq!(duplicate_received_groups.len(), 0);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_sync_all_groups() {
let alix = ClientBuilder::new_test_client(&generate_local_wallet()).await;
let bo = ClientBuilder::new_test_client(&generate_local_wallet()).await;

let alix_bo_group1 = alix
.create_group(None, GroupMetadataOptions::default())
.unwrap();
let alix_bo_group2 = alix
.create_group(None, GroupMetadataOptions::default())
.unwrap();
alix_bo_group1
.add_members_by_inbox_id(&alix, vec![bo.inbox_id()])
.await
.unwrap();
alix_bo_group2
.add_members_by_inbox_id(&alix, vec![bo.inbox_id()])
.await
.unwrap();

let bob_received_groups = bo.sync_welcomes().await.unwrap();
assert_eq!(bob_received_groups.len(), 2);

let bo_groups = bo.find_groups(None, None, None, None).unwrap();
let bo_group1 = bo.group(alix_bo_group1.clone().group_id).unwrap();
let bo_messages1 = bo_group1
.find_messages(None, None, None, None, None)
.unwrap();
assert_eq!(bo_messages1.len(), 0);
let bo_group2 = bo.group(alix_bo_group2.clone().group_id).unwrap();
let bo_messages2 = bo_group2
.find_messages(None, None, None, None, None)
.unwrap();
assert_eq!(bo_messages2.len(), 0);
alix_bo_group1
.send_message(vec![1, 2, 3].as_slice(), &alix)
.await
.unwrap();
alix_bo_group2
.send_message(vec![1, 2, 3].as_slice(), &alix)
.await
.unwrap();

bo.sync_all_groups(bo_groups).await.unwrap();

let bo_messages1 = bo_group1
.find_messages(None, None, None, None, None)
.unwrap();
assert_eq!(bo_messages1.len(), 1);
let bo_group2 = bo.group(alix_bo_group2.clone().group_id).unwrap();
let bo_messages2 = bo_group2
.find_messages(None, None, None, None, None)
.unwrap();
assert_eq!(bo_messages2.len(), 1);
}

#[tokio::test]
async fn test_can_message() {
// let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await;
Expand Down
5 changes: 4 additions & 1 deletion xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,10 @@ impl MlsGroup {

// Load the stored MLS group from the OpenMLS provider's keystore
#[tracing::instrument(level = "trace", skip_all)]
fn load_mls_group(&self, provider: impl OpenMlsProvider) -> Result<OpenMlsGroup, GroupError> {
pub fn load_mls_group(
&self,
provider: impl OpenMlsProvider,
) -> Result<OpenMlsGroup, GroupError> {
let mls_group =
OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))
.map_err(|_| GroupError::GroupNotFound)?
Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/groups/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl MlsGroup {
}

#[tracing::instrument(level = "trace", skip(client, self, conn))]
pub(super) async fn sync_with_conn<ApiClient>(
pub async fn sync_with_conn<ApiClient>(
&self,
conn: DbConnection,
client: &Client<ApiClient>,
Expand Down Expand Up @@ -1000,7 +1000,7 @@ impl MlsGroup {
Ok(())
}

pub(super) async fn maybe_update_installations<ApiClient>(
pub async fn maybe_update_installations<ApiClient>(
&self,
conn: DbConnection,
update_interval: Option<i64>,
Expand Down

0 comments on commit 2ca8be1

Please sign in to comment.