Skip to content

Commit

Permalink
Simplify aborting in streams (#989)
Browse files Browse the repository at this point in the history
## Summary

- When you receive a commit from a stream, we need to abort processing the message and sync instead. Previously we did that by decrypting the message and looking at the content to see what kind of message it was.
- There is a better way. There was a private method on the `PrivateMessageIn` that I have now made public that will let us see what kind of message it is _before_ decrypting.
- This should make streaming more performant and reliable, since we don't have to decrypt the message twice

## Related

xmtp/openmls#36
  • Loading branch information
neekolas authored Aug 23, 2024
1 parent 68d90b2 commit bda540d
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 46 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ futures = "0.3.30"
futures-core = "0.3.30"
hex = "0.4.3"
log = { version = "0.4" }
openmls = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5", default-features = false }
openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" }
openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" }
openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "cf42738018d093434c955a1b50a9de34cc12b8c5" }
openmls = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76", default-features = false }
openmls_basic_credential = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" }
openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" }
openmls_traits = { git = "https://github.com/xmtp/openmls", rev = "87e7e257d8eb15d6662b104518becfc75ef6db76" }
pbjson = "0.6.0"
pbjson-types = "0.6.0"
prost = "^0.12"
Expand Down
12 changes: 6 additions & 6 deletions bindings_ffi/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions bindings_node/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion xmtp_api_http/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ pub async fn create_grpc_stream<
endpoint: String,
http_client: reqwest::Client,
) -> Result<BoxStream<'static, Result<R, Error>>, Error> {
log::info!("About to spawn stream");
let stream = async_stream::stream! {
log::debug!("Spawning grpc http stream");
log::info!("Spawning grpc http stream");
let request = http_client
.post(endpoint)
.json(&request)
Expand Down
53 changes: 50 additions & 3 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,8 +1266,9 @@ mod tests {
use xmtp_proto::xmtp::mls::message_contents::EncodedContent;

use crate::{
assert_logged,
assert_err, assert_logged,
builder::ClientBuilder,
client::MessageProcessingError,
codecs::{group_updated::GroupUpdatedCodec, ContentCodec},
groups::{
build_group_membership_extension,
Expand Down Expand Up @@ -3088,11 +3089,11 @@ mod tests {
alix1_group
.publish_intents(&alix1_provider, &alix1)
.await
.expect_err("Expected an error that publish was canceled");
.expect("Expect publish to be OK");
alix1_group
.publish_intents(&alix1_provider, &alix1)
.await
.expect_err("Expected an error that publish was canceled");
.expect("Expected publish to be OK");

// Now I am going to sync twice
alix1_group
Expand Down Expand Up @@ -3153,4 +3154,50 @@ mod tests {
.iter()
.any(|m| m.decrypted_message_bytes == "hi from alix1".as_bytes()));
}

#[tokio::test(flavor = "multi_thread", worker_threads = 5)]
async fn respect_allow_epoch_increment() {
let wallet = generate_local_wallet();
let client = ClientBuilder::new_test_client(&wallet).await;

let group = client
.create_group(None, GroupMetadataOptions::default())
.unwrap();

let _client_2 = ClientBuilder::new_test_client(&wallet).await;

// Sync the group to get the message adding client_2 published to the network
group.sync(&client).await.unwrap();

// Retrieve the envelope for the commit from the network
let messages = client
.api_client
.query_group_messages(group.group_id.clone(), None)
.await
.unwrap();

let first_envelope = messages.first().unwrap();

let Some(xmtp_proto::xmtp::mls::api::v1::group_message::Version::V1(first_message)) =
first_envelope.clone().version
else {
panic!("wrong message format")
};
let provider = client.mls_provider().unwrap();
let mut openmls_group = group.load_mls_group(&provider).unwrap();
let process_result = group
.process_message(
&client,
&mut openmls_group,
&provider,
&first_message,
false,
)
.await;

assert_err!(
process_result,
MessageProcessingError::EpochIncrementNotAllowed
);
}
}
14 changes: 12 additions & 2 deletions xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,16 @@ impl MlsGroup {
);

if let Some(GroupError::ReceiveError(_)) = process_result.as_ref().err() {
self.sync_with_conn(&client.mls_provider()?, &client)
.await?;
// Swallow errors here, since another process may have successfully saved the message
// to the DB
match self.sync_with_conn(&client.mls_provider()?, &client).await {
Ok(_) => {
log::debug!("Sync triggered by streamed message successful")
}
Err(err) => {
log::warn!("Sync triggered by streamed message failed: {}", err);
}
};
} else if process_result.is_err() {
log::error!("Process stream entry {:?}", process_result.err());
}
Expand Down Expand Up @@ -309,6 +317,8 @@ mod tests {
});
// just to make sure stream is started
let _ = start_rx.await;
// Adding in a sleep, since the HTTP API client may acknowledge requests before they are ready
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

amal_group
.add_members_by_inbox_id(&amal, vec![bola.inbox_id()])
Expand Down
22 changes: 8 additions & 14 deletions xmtp_mls/src/groups/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use log::debug;
use openmls::{
credentials::BasicCredential,
extensions::Extensions,
framing::ProtocolMessage,
framing::{ContentType, ProtocolMessage},
group::{GroupEpoch, StagedCommit},
prelude::{
tls_codec::{Deserialize, Serialize},
Expand Down Expand Up @@ -266,7 +266,6 @@ impl MlsGroup {
provider: &XmtpOpenMlsProvider,
message: ProtocolMessage,
envelope_timestamp_ns: u64,
allow_epoch_increment: bool,
) -> Result<IntentState, MessageProcessingError> {
if intent.state == IntentState::Committed {
return Ok(IntentState::Committed);
Expand All @@ -290,9 +289,6 @@ impl MlsGroup {
| IntentKind::UpdateAdminList
| IntentKind::MetadataUpdate
| IntentKind::UpdatePermission => {
if !allow_epoch_increment {
return Err(MessageProcessingError::EpochIncrementNotAllowed);
}
if let Some(published_in_epoch) = intent.published_in_epoch {
let published_in_epoch_u64 = published_in_epoch as u64;
let group_epoch_u64 = group_epoch.as_u64();
Expand Down Expand Up @@ -376,7 +372,6 @@ impl MlsGroup {
provider: &XmtpOpenMlsProvider,
message: PrivateMessageIn,
envelope_timestamp_ns: u64,
allow_epoch_increment: bool,
) -> Result<(), MessageProcessingError> {
let decrypted_message = openmls_group.process_message(provider, message)?;
let (sender_inbox_id, sender_installation_id) =
Expand Down Expand Up @@ -548,9 +543,6 @@ impl MlsGroup {
// intentionally left blank.
}
ProcessedMessageContent::StagedCommitMessage(staged_commit) => {
if !allow_epoch_increment {
return Err(MessageProcessingError::EpochIncrementNotAllowed);
}
log::info!(
"[{}] received staged commit. Merging and clearing any pending commits",
self.context.inbox_id()
Expand Down Expand Up @@ -600,6 +592,10 @@ impl MlsGroup {
)),
}?;

if !allow_epoch_increment && message.content_type() == ContentType::Commit {
return Err(MessageProcessingError::EpochIncrementNotAllowed);
}

let intent = provider
.conn_ref()
.find_group_intent_by_payload_hash(sha256(envelope.data.as_slice()));
Expand All @@ -622,7 +618,6 @@ impl MlsGroup {
provider,
message.into(),
envelope.created_ns,
allow_epoch_increment,
)
.await?
{
Expand Down Expand Up @@ -654,7 +649,6 @@ impl MlsGroup {
provider,
message,
envelope.created_ns,
allow_epoch_increment,
)
.await
}
Expand Down Expand Up @@ -875,8 +869,8 @@ impl MlsGroup {
intent.kind
);
if has_staged_commit {
log::info!("Canceling all further publishes, since a commit was found");
return Err(GroupError::PublishCancelled);
log::info!("Commit sent. Stopping further publishes for this round");
return Ok(());
}
}
Ok(None) => {
Expand Down Expand Up @@ -1011,7 +1005,7 @@ impl MlsGroup {
}
}

#[tracing::instrument(level = "trace", skip(conn, client))]
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) async fn post_commit<ApiClient>(
&self,
conn: &DbConnection,
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ mod tests {
#[macro_export]
macro_rules! assert_err {
( $x:expr , $y:pat $(,)? ) => {
assert!(matches!($x, Err($y)));
assert!(matches!($x, Err($y)))
};

( $x:expr, $y:pat $(,)?, $($msg:tt)+) => {{
Expand Down
Loading

0 comments on commit bda540d

Please sign in to comment.