diff --git a/Cargo.lock b/Cargo.lock index dcfbe4341..93f528a86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1799,18 +1799,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "flume" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" -dependencies = [ - "futures-core", - "futures-sink", - "nanorand", - "spin 0.9.8", -] - [[package]] name = "fnv" version = "1.0.7" @@ -2646,15 +2634,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -3084,15 +3063,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom", -] - [[package]] name = "native-tls" version = "0.2.12" @@ -3902,7 +3872,7 @@ checksum = "5bb182580f71dd070f88d01ce3de9f4da5021db7115d2e1c3605a754153b77c1" dependencies = [ "bytes", "heck", - "itertools 0.12.1", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -3922,7 +3892,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" dependencies = [ "anyhow", - "itertools 0.12.1", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.72", @@ -4764,17 +4734,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "smart-default" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eb01866308440fc64d6c44d9e86c5cc17adfe33c4d6eed55da9145044d0ffc1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "socket2" version = "0.5.7" @@ -4810,9 +4769,6 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] [[package]] name = "spki" @@ -5531,6 +5487,17 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -6243,8 +6210,6 @@ name = "xmtp_api_http" version = "0.0.1" dependencies = [ "async-stream", - "async-trait", - "bytes", "futures", "log", "reqwest 0.12.5", @@ -6263,7 +6228,6 @@ version = "0.1.0" dependencies = [ "clap", "ethers", - "ethers-core", "femme", "futures", "hex", @@ -6291,7 +6255,6 @@ dependencies = [ "curve25519-dalek", "ecdsa 0.16.9", "ethers", - "ethers-core", "getrandom", "hex", "k256 0.13.3", @@ -6316,7 +6279,6 @@ dependencies = [ "ed25519", "ed25519-dalek", "ethers", - "ethers-core", "futures", "hex", "log", @@ -6347,7 +6309,6 @@ dependencies = [ "anyhow", "async-barrier", "async-stream", - "async-trait", "bincode", "chrono", "criterion", @@ -6356,8 +6317,6 @@ dependencies = [ "diesel_migrations", "ed25519-dalek", "ethers", - "ethers-core", - "flume", "futures", "hex", "indicatif", @@ -6377,7 +6336,6 @@ dependencies = [ "serde", "serde_json", "sha2 0.10.8", - "smart-default", "tempfile", "thiserror", "tls_codec", @@ -6389,6 +6347,7 @@ dependencies = [ "tracing-log", "tracing-subscriber", "tracing-test", + "trait-variant", "xmtp_api_grpc", "xmtp_api_http", "xmtp_cryptography", @@ -6401,7 +6360,6 @@ dependencies = [ name = "xmtp_proto" version = "0.0.1" dependencies = [ - "async-trait", "futures", "futures-core", "openmls", @@ -6412,6 +6370,7 @@ dependencies = [ "prost-types", "serde", "tonic", + "trait-variant", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index cefde727d..947953a45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,12 +26,12 @@ version = "0.0.1" anyhow = "1.0" async-stream = "0.3" async-trait = "0.1.77" +trait-variant = "0.1.2" chrono = "0.4.38" ctor = "0.2" ed25519 = "2.2.3" ed25519-dalek = "2.1.1" ethers = "2.0.11" -ethers-core = "2.0.4" futures = "0.3.30" futures-core = "0.3.30" getrandom = { version = "0.2", default-features = false } diff --git a/bindings_ffi/Cargo.lock b/bindings_ffi/Cargo.lock index 5a3700f04..a9ee5a85d 100644 --- a/bindings_ffi/Cargo.lock +++ b/bindings_ffi/Cargo.lock @@ -3502,7 +3502,7 @@ checksum = "5bb182580f71dd070f88d01ce3de9f4da5021db7115d2e1c3605a754153b77c1" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.11.0", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -3522,7 +3522,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.48", @@ -4359,17 +4359,6 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" -[[package]] -name = "smart-default" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eb01866308440fc64d6c44d9e86c5cc17adfe33c4d6eed55da9145044d0ffc1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.48", -] - [[package]] name = "smawk" version = "0.3.2" @@ -5040,6 +5029,17 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -5772,7 +5772,6 @@ dependencies = [ "curve25519-dalek", "ecdsa 0.16.9", "ethers", - "ethers-core", "getrandom", "hex", "k256 0.13.3", @@ -5795,7 +5794,6 @@ dependencies = [ "ed25519", "ed25519-dalek", "ethers", - "ethers-core", "futures", "hex", "log", @@ -5823,14 +5821,11 @@ version = "0.0.1" dependencies = [ "aes-gcm", "async-stream", - "async-trait", "bincode", "chrono", "diesel", "diesel_migrations", "ed25519-dalek", - "ethers", - "ethers-core", "futures", "hex", "libsqlite3-sys", @@ -5846,13 +5841,13 @@ dependencies = [ "serde", "serde_json", "sha2", - "smart-default", "thiserror", "tls_codec", "tokio", "tokio-stream", "toml 0.8.8", "tracing", + "trait-variant", "xmtp_cryptography", "xmtp_id", "xmtp_proto", @@ -5863,7 +5858,6 @@ dependencies = [ name = "xmtp_proto" version = "0.0.1" dependencies = [ - "async-trait", "futures", "futures-core", "openmls", @@ -5874,6 +5868,7 @@ dependencies = [ "prost-types", "serde", "tonic", + "trait-variant", ] [[package]] @@ -5910,7 +5905,6 @@ version = "0.0.1" dependencies = [ "env_logger", "ethers", - "ethers-core", "futures", "log", "parking_lot", diff --git a/bindings_ffi/Cargo.toml b/bindings_ffi/Cargo.toml index 4685dcb6f..89cfc4892 100644 --- a/bindings_ffi/Cargo.toml +++ b/bindings_ffi/Cargo.toml @@ -43,7 +43,6 @@ path = "src/bin.rs" [dev-dependencies] ethers = "2.0.13" -ethers-core = "2.0.13" tempfile = "3.5.0" tokio = { version = "1.28.1", features = ["full"] } tokio-test = "0.4" diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index 8f4a4121b..bf46033e1 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -961,7 +961,7 @@ impl FfiGroup { self.created_at_ns, ); let message = group - .process_streamed_group_message(envelope_bytes, self.inner_client.clone()) + .process_streamed_group_message(envelope_bytes, &self.inner_client) .await?; let ffi_message = message.into(); @@ -1576,11 +1576,11 @@ mod tests { }; use super::{create_client, FfiMessage, FfiMessageCallback, FfiXmtpClient}; - use ethers::utils::hex; - use ethers_core::rand::{ + use ethers::core::rand::{ self, distributions::{Alphanumeric, DistString}, }; + use ethers::utils::hex; use tokio::{sync::Notify, time::error::Elapsed}; use xmtp_cryptography::{signature::RecoverableSignature, utils::rng}; use xmtp_id::associations::generate_inbox_id; diff --git a/bindings_ffi/src/v2.rs b/bindings_ffi/src/v2.rs index 090306e97..f98a9a803 100644 --- a/bindings_ffi/src/v2.rs +++ b/bindings_ffi/src/v2.rs @@ -591,7 +591,7 @@ mod tests { let msg = "TestVector1"; let sig_hash = "19d6bec562518e365d07ba3cce26d08a5fffa2cbb1e7fe03c1f2d6a722fd3a5e544097b91f8f8cd11d43b032659f30529139ab1a9ecb6c81ed4a762179e87db81c"; - let sig_bytes = ethers_core::utils::hex::decode(sig_hash).unwrap(); + let sig_bytes = ethers::core::utils::hex::decode(sig_hash).unwrap(); let recovered_addr = crate::v2::recover_address(sig_bytes, msg.to_string()).unwrap(); assert_eq!(recovered_addr, addr.to_lowercase()); } diff --git a/bindings_node/Cargo.lock b/bindings_node/Cargo.lock index e5c2040a8..acfd0213f 100644 --- a/bindings_node/Cargo.lock +++ b/bindings_node/Cargo.lock @@ -2223,15 +2223,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -3275,7 +3266,7 @@ checksum = "5bb182580f71dd070f88d01ce3de9f4da5021db7115d2e1c3605a754153b77c1" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.12.1", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -3295,7 +3286,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" dependencies = [ "anyhow", - "itertools 0.12.1", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.64", @@ -4074,17 +4065,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "smart-default" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eb01866308440fc64d6c44d9e86c5cc17adfe33c4d6eed55da9145044d0ffc1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.64", -] - [[package]] name = "socket2" version = "0.5.7" @@ -4647,6 +4627,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.64", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -5229,7 +5220,6 @@ dependencies = [ "curve25519-dalek", "ecdsa 0.16.9", "ethers", - "ethers-core", "getrandom", "hex", "k256 0.13.3", @@ -5252,7 +5242,6 @@ dependencies = [ "ed25519", "ed25519-dalek", "ethers", - "ethers-core", "futures", "hex", "log", @@ -5280,14 +5269,11 @@ version = "0.0.1" dependencies = [ "aes-gcm", "async-stream", - "async-trait", "bincode", "chrono", "diesel", "diesel_migrations", "ed25519-dalek", - "ethers", - "ethers-core", "futures", "hex", "libsqlite3-sys", @@ -5303,13 +5289,13 @@ dependencies = [ "serde", "serde_json", "sha2", - "smart-default", "thiserror", "tls_codec", "tokio", "tokio-stream", "toml", "tracing", + "trait-variant", "xmtp_cryptography", "xmtp_id", "xmtp_proto", @@ -5320,7 +5306,6 @@ dependencies = [ name = "xmtp_proto" version = "0.0.1" dependencies = [ - "async-trait", "futures", "futures-core", "openmls", @@ -5331,6 +5316,7 @@ dependencies = [ "prost-types", "serde", "tonic", + "trait-variant", ] [[package]] diff --git a/bindings_node/src/groups.rs b/bindings_node/src/groups.rs index acf79323a..a652ed2eb 100644 --- a/bindings_node/src/groups.rs +++ b/bindings_node/src/groups.rs @@ -201,7 +201,7 @@ impl NapiGroup { ); let envelope_bytes: Vec = envelope_bytes.deref().to_vec(); let message = group - .process_streamed_group_message(envelope_bytes, self.inner_client.clone()) + .process_streamed_group_message(envelope_bytes, &self.inner_client) .await .map_err(ErrorWrapper::from)?; diff --git a/examples/cli/Cargo.toml b/examples/cli/Cargo.toml index 861e168a0..3f8329805 100644 --- a/examples/cli/Cargo.toml +++ b/examples/cli/Cargo.toml @@ -15,7 +15,6 @@ path = "cli-client.rs" [dependencies] clap = { version = "4.4.6", features = ["derive"] } ethers = "2.0.4" -ethers-core = "2.0.4" femme = "2.2.1" futures = "0.3.28" hex = "0.4.3" diff --git a/xmtp_api_grpc/src/grpc_api_helper.rs b/xmtp_api_grpc/src/grpc_api_helper.rs index 7f9235405..6e61a81ac 100644 --- a/xmtp_api_grpc/src/grpc_api_helper.rs +++ b/xmtp_api_grpc/src/grpc_api_helper.rs @@ -7,13 +7,13 @@ use futures::stream::{AbortHandle, Abortable}; use futures::{SinkExt, Stream, StreamExt, TryStreamExt}; use tokio::sync::oneshot; use tonic::transport::ClientTlsConfig; -use tonic::{async_trait, metadata::MetadataValue, transport::Channel, Request, Streaming}; +use tonic::{metadata::MetadataValue, transport::Channel, Request, Streaming}; -use xmtp_proto::api_client::ClientWithMetadata; +use xmtp_proto::api_client::{ClientWithMetadata, XmtpMlsStreams}; +use xmtp_proto::xmtp::mls::api::v1::{GroupMessage, WelcomeMessage}; use xmtp_proto::{ api_client::{ - Error, ErrorKind, GroupMessageStream, MutableApiSubscription, WelcomeMessageStream, - XmtpApiClient, XmtpApiSubscription, XmtpMlsClient, + Error, ErrorKind, MutableApiSubscription, XmtpApiClient, XmtpApiSubscription, XmtpMlsClient, }, xmtp::identity::api::v1::identity_api_client::IdentityApiClient as ProtoIdentityApiClient, xmtp::message_api::v1::{ @@ -131,7 +131,6 @@ impl ClientWithMetadata for Client { } } -#[async_trait] impl XmtpApiClient for Client { type Subscription = Subscription; type MutableSubscription = GrpcMutableSubscription; @@ -318,7 +317,6 @@ impl Stream for GrpcMutableSubscription { } } -#[async_trait] impl MutableApiSubscription for GrpcMutableSubscription { async fn update(&mut self, req: SubscribeRequest) -> Result<(), Error> { self.update_channel @@ -334,8 +332,6 @@ impl MutableApiSubscription for GrpcMutableSubscription { self.update_channel.close_channel(); } } - -#[async_trait] impl XmtpMlsClient for Client { #[tracing::instrument(level = "trace", skip_all)] async fn upload_key_package(&self, req: UploadKeyPackageRequest) -> Result<(), Error> { @@ -405,11 +401,62 @@ impl XmtpMlsClient for Client { res.map(|r| r.into_inner()) .map_err(|e| Error::new(ErrorKind::MlsError).with(e)) } +} + +pub struct GroupMessageStream { + inner: tonic::codec::Streaming, +} + +impl From> for GroupMessageStream { + fn from(inner: tonic::codec::Streaming) -> Self { + GroupMessageStream { inner } + } +} + +impl Stream for GroupMessageStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner + .poll_next_unpin(cx) + .map(|data| data.map(|v| v.map_err(|e| Error::new(ErrorKind::SubscribeError).with(e)))) + } +} + +pub struct WelcomeMessageStream { + inner: tonic::codec::Streaming, +} + +impl From> for WelcomeMessageStream { + fn from(inner: tonic::codec::Streaming) -> Self { + WelcomeMessageStream { inner } + } +} + +impl Stream for WelcomeMessageStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner + .poll_next_unpin(cx) + .map(|data| data.map(|v| v.map_err(|e| Error::new(ErrorKind::SubscribeError).with(e)))) + } +} + +impl XmtpMlsStreams for Client { + type GroupMessageStream<'a> = GroupMessageStream; + type WelcomeMessageStream<'a> = WelcomeMessageStream; async fn subscribe_group_messages( &self, req: SubscribeGroupMessagesRequest, - ) -> Result { + ) -> Result, Error> { let client = &mut self.mls_client.clone(); let res = client .subscribe_group_messages(self.build_request(req)) @@ -417,16 +464,13 @@ impl XmtpMlsClient for Client { .map_err(|e| Error::new(ErrorKind::MlsError).with(e))?; let stream = res.into_inner(); - - let new_stream = stream.map_err(|e| Error::new(ErrorKind::SubscribeError).with(e)); - - Ok(Box::pin(new_stream)) + Ok(stream.into()) } async fn subscribe_welcome_messages( &self, req: SubscribeWelcomeMessagesRequest, - ) -> Result { + ) -> Result, Error> { let client = &mut self.mls_client.clone(); let res = client .subscribe_welcome_messages(self.build_request(req)) @@ -435,8 +479,6 @@ impl XmtpMlsClient for Client { let stream = res.into_inner(); - let new_stream = stream.map_err(|e| Error::new(ErrorKind::SubscribeError).with(e)); - - Ok(Box::pin(new_stream)) + Ok(stream.into()) } } diff --git a/xmtp_api_grpc/src/identity.rs b/xmtp_api_grpc/src/identity.rs index 0fbb27e6d..92531e0a6 100644 --- a/xmtp_api_grpc/src/identity.rs +++ b/xmtp_api_grpc/src/identity.rs @@ -1,4 +1,3 @@ -use tonic::async_trait; use xmtp_proto::{ api_client::{Error, ErrorKind, XmtpIdentityClient}, xmtp::identity::api::v1::{ @@ -10,7 +9,6 @@ use xmtp_proto::{ use crate::Client; -#[async_trait] impl XmtpIdentityClient for Client { #[tracing::instrument(level = "trace", skip_all)] async fn publish_identity_update( diff --git a/xmtp_api_grpc/src/lib.rs b/xmtp_api_grpc/src/lib.rs index 10472c2a8..a774b86bc 100644 --- a/xmtp_api_grpc/src/lib.rs +++ b/xmtp_api_grpc/src/lib.rs @@ -5,7 +5,7 @@ mod identity; pub const LOCALHOST_ADDRESS: &str = "http://localhost:5556"; pub const DEV_ADDRESS: &str = "https://grpc.dev.xmtp.network:443"; -pub use grpc_api_helper::Client; +pub use grpc_api_helper::{Client, GroupMessageStream, WelcomeMessageStream}; #[cfg(test)] mod tests { diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index 13852a3dc..2a2b9cb8d 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -8,8 +8,6 @@ crate-type = ["cdylib", "rlib"] [dependencies] async-stream.workspace = true -async-trait = { workspace = true } -bytes = "1.7" futures = { workspace = true } log.workspace = true reqwest = { version = "0.12.5", features = ["json", "stream"] } diff --git a/xmtp_api_http/src/lib.rs b/xmtp_api_http/src/lib.rs index aa6fca4b4..d204b106d 100755 --- a/xmtp_api_http/src/lib.rs +++ b/xmtp_api_http/src/lib.rs @@ -3,7 +3,7 @@ pub mod constants; mod util; -use async_trait::async_trait; +use futures::stream; use reqwest::header; use util::{create_grpc_stream, handle_error}; use xmtp_proto::api_client::{ClientWithMetadata, Error, ErrorKind, XmtpIdentityClient}; @@ -14,7 +14,7 @@ use xmtp_proto::xmtp::identity::api::v1::{ }; use xmtp_proto::xmtp::mls::api::v1::{GroupMessage, WelcomeMessage}; use xmtp_proto::{ - api_client::{GroupMessageStream, WelcomeMessageStream, XmtpMlsClient}, + api_client::{XmtpMlsClient, XmtpMlsStreams}, xmtp::mls::api::v1::{ FetchKeyPackagesRequest, FetchKeyPackagesResponse, QueryGroupMessagesRequest, QueryGroupMessagesResponse, QueryWelcomeMessagesRequest, QueryWelcomeMessagesResponse, @@ -120,8 +120,6 @@ impl ClientWithMetadata for XmtpHttpApiClient { } } -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl XmtpMlsClient for XmtpHttpApiClient { async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error> { let res = self @@ -230,11 +228,29 @@ impl XmtpMlsClient for XmtpHttpApiClient { log::debug!("query_welcome_messages"); handle_error(&*res) } +} + +impl XmtpMlsStreams for XmtpHttpApiClient { + // hard to avoid boxing here: + // 1.) use `hyper` instead of `reqwest` and create our own `Stream` type + // 2.) ise `impl Stream` in return of `XmtpMlsStreams` but that + // breaks the `mockall::` functionality, since `mockall` does not support `impl Trait` in + // `Trait` yet. + + #[cfg(not(target_arch = "wasm32"))] + type GroupMessageStream<'a> = stream::BoxStream<'a, Result>; + #[cfg(not(target_arch = "wasm32"))] + type WelcomeMessageStream<'a> = stream::BoxStream<'a, Result>; + + #[cfg(target_arch = "wasm32")] + type GroupMessageStream<'a> = stream::LocalBoxStream<'a, Result>; + #[cfg(target_arch = "wasm32")] + type WelcomeMessageStream<'a> = stream::LocalBoxStream<'a, Result>; async fn subscribe_group_messages( &self, request: SubscribeGroupMessagesRequest, - ) -> Result { + ) -> Result, Error> { log::debug!("subscribe_group_messages"); Ok(create_grpc_stream::<_, GroupMessage>( request, @@ -246,7 +262,7 @@ impl XmtpMlsClient for XmtpHttpApiClient { async fn subscribe_welcome_messages( &self, request: SubscribeWelcomeMessagesRequest, - ) -> Result { + ) -> Result, Error> { log::debug!("subscribe_welcome_messages"); Ok(create_grpc_stream::<_, WelcomeMessage>( request, @@ -256,8 +272,6 @@ impl XmtpMlsClient for XmtpHttpApiClient { } } -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl XmtpIdentityClient for XmtpHttpApiClient { async fn publish_identity_update( &self, diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index 056471eef..52051b1b8 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -1,4 +1,7 @@ -use futures::{stream, Stream, StreamExt}; +use futures::{ + stream::{self, StreamExt}, + Stream, +}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::Deserializer; use std::io::Read; @@ -49,9 +52,7 @@ pub fn create_grpc_stream< endpoint: String, http_client: reqwest::Client, ) -> stream::LocalBoxStream<'static, Result> { - log::info!("About to spawn stream"); - let stream = create_grpc_stream_inner(request, endpoint, http_client); - stream.boxed_local() + create_grpc_stream_inner(request, endpoint, http_client).boxed_local() } #[cfg(not(target_arch = "wasm32"))] @@ -63,12 +64,10 @@ pub fn create_grpc_stream< endpoint: String, http_client: reqwest::Client, ) -> stream::BoxStream<'static, Result> { - log::info!("About to spawn stream"); - let stream = create_grpc_stream_inner(request, endpoint, http_client); - stream.boxed() + create_grpc_stream_inner(request, endpoint, http_client).boxed() } -fn create_grpc_stream_inner< +pub fn create_grpc_stream_inner< T: Serialize + Send + 'static, R: DeserializeOwned + Send + std::fmt::Debug + 'static, >( diff --git a/xmtp_cryptography/Cargo.toml b/xmtp_cryptography/Cargo.toml index 7e06749f0..67ab6baa2 100644 --- a/xmtp_cryptography/Cargo.toml +++ b/xmtp_cryptography/Cargo.toml @@ -8,7 +8,6 @@ version.workspace = true curve25519-dalek = "4" ecdsa = "0.16.9" ethers = { workspace = true } -ethers-core = { workspace = true } hex = { workspace = true } k256 = { version = "0.13.3", features = ["ecdh"] } log = { workspace = true } diff --git a/xmtp_cryptography/src/signature.rs b/xmtp_cryptography/src/signature.rs index 4f25414fe..687b3bce8 100644 --- a/xmtp_cryptography/src/signature.rs +++ b/xmtp_cryptography/src/signature.rs @@ -1,6 +1,6 @@ use curve25519_dalek::{edwards::CompressedEdwardsY, traits::IsIdentity}; +use ethers::core::types::{self as ethers_types, H160}; use ethers::types::Address; -use ethers_core::types::{self as ethers_types, H160}; pub use k256::ecdsa::{RecoveryId, SigningKey, VerifyingKey}; use k256::Secp256k1; use serde::{Deserialize, Serialize}; @@ -98,8 +98,8 @@ impl From<(ecdsa::Signature, RecoveryId)> for RecoverableSignature { } } -impl From for RecoverableSignature { - fn from(value: ethers_core::types::Signature) -> Self { +impl From for RecoverableSignature { + fn from(value: ethers::core::types::Signature) -> Self { RecoverableSignature::Eip191Signature(value.to_vec()) } } diff --git a/xmtp_cryptography/src/utils.rs b/xmtp_cryptography/src/utils.rs index efbc1c502..35e13d4ba 100644 --- a/xmtp_cryptography/src/utils.rs +++ b/xmtp_cryptography/src/utils.rs @@ -1,5 +1,5 @@ +use ethers::core::utils::keccak256; pub use ethers::prelude::LocalWallet; -use ethers_core::utils::keccak256; use k256::ecdsa::VerifyingKey; use rand::{CryptoRng, RngCore, SeedableRng}; use rand_chacha::ChaCha20Rng; diff --git a/xmtp_id/Cargo.toml b/xmtp_id/Cargo.toml index 449c5f5fa..94f82f343 100644 --- a/xmtp_id/Cargo.toml +++ b/xmtp_id/Cargo.toml @@ -8,7 +8,6 @@ async-trait.workspace = true chrono.workspace = true ed25519-dalek = { workspace = true, features = ["digest"] } ed25519.workspace = true -ethers-core.workspace = true ethers.workspace = true futures.workspace = true hex.workspace = true diff --git a/xmtp_id/src/associations/association_log.rs b/xmtp_id/src/associations/association_log.rs index 2bccf68d0..bf54b4688 100644 --- a/xmtp_id/src/associations/association_log.rs +++ b/xmtp_id/src/associations/association_log.rs @@ -3,7 +3,6 @@ use super::member::{Member, MemberIdentifier, MemberKind}; use super::serialization::{from_identity_update_proto, DeserializationError}; use super::signature::{Signature, SignatureError, SignatureKind}; use super::state::AssociationState; -use async_trait::async_trait; use prost::Message; use thiserror::Error; use xmtp_proto::xmtp::identity::associations::IdentityUpdate as IdentityUpdateProto; @@ -38,8 +37,7 @@ pub enum AssociationError { MissingIdentityUpdate, } -#[async_trait] -pub trait IdentityAction: Send + 'static { +pub(crate) trait IdentityAction: Send + 'static { async fn update_state( &self, existing_state: Option, @@ -65,7 +63,6 @@ pub struct CreateInbox { pub initial_address_signature: Box, } -#[async_trait] impl IdentityAction for CreateInbox { async fn update_state( &self, @@ -110,7 +107,6 @@ pub struct AddAssociation { pub existing_member_signature: Box, } -#[async_trait::async_trait] impl IdentityAction for AddAssociation { async fn update_state( &self, @@ -205,7 +201,6 @@ pub struct RevokeAssociation { pub revoked_member: MemberIdentifier, } -#[async_trait] impl IdentityAction for RevokeAssociation { async fn update_state( &self, @@ -261,7 +256,6 @@ pub struct ChangeRecoveryAddress { pub new_recovery_address: String, } -#[async_trait] impl IdentityAction for ChangeRecoveryAddress { async fn update_state( &self, @@ -299,7 +293,6 @@ pub enum Action { ChangeRecoveryAddress(ChangeRecoveryAddress), } -#[async_trait] impl IdentityAction for Action { async fn update_state( &self, @@ -366,7 +359,6 @@ impl TryFrom> for IdentityUpdate { } } -#[async_trait] impl IdentityAction for IdentityUpdate { async fn update_state( &self, diff --git a/xmtp_id/src/associations/mod.rs b/xmtp_id/src/associations/mod.rs index 198a84a26..7e1c423d0 100644 --- a/xmtp_id/src/associations/mod.rs +++ b/xmtp_id/src/associations/mod.rs @@ -16,6 +16,8 @@ pub use self::serialization::{map_vec, try_map_vec, DeserializationError}; pub use self::signature::*; pub use self::state::{AssociationState, AssociationStateDiff}; +use crate::associations::association_log::IdentityAction; + // Apply a single IdentityUpdate to an existing AssociationState pub async fn apply_update( initial_state: AssociationState, diff --git a/xmtp_id/src/associations/signature.rs b/xmtp_id/src/associations/signature.rs index 7a4190d51..fefb5fb4c 100644 --- a/xmtp_id/src/associations/signature.rs +++ b/xmtp_id/src/associations/signature.rs @@ -182,8 +182,6 @@ pub struct SmartContractWalletSignature { chain_rpc_url: String, } -unsafe impl Send for SmartContractWalletSignature {} - impl SmartContractWalletSignature { pub fn new( signature_text: String, diff --git a/xmtp_id/src/lib.rs b/xmtp_id/src/lib.rs index bf7ee091a..f6fa1ad29 100644 --- a/xmtp_id/src/lib.rs +++ b/xmtp_id/src/lib.rs @@ -40,7 +40,6 @@ pub async fn is_smart_contract( Ok(!code.is_empty()) } -// TODO: Remove this trait pub trait InboxOwner { /// Get address of the wallet. fn get_address(&self) -> String; @@ -54,7 +53,7 @@ impl InboxOwner for LocalWallet { } fn sign(&self, text: &str) -> Result { - let message_hash = ethers_core::utils::hash_message(text); + let message_hash = ethers::core::utils::hash_message(text); Ok(self.sign_hash(message_hash)?.to_vec().into()) } } diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index c992137ba..91ab0ded0 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -26,7 +26,7 @@ test-utils = [] [dependencies] aes-gcm = { version = "0.10.3", features = ["std"] } async-stream.workspace = true -async-trait.workspace = true +trait-variant.workspace = true bincode = "1.3.3" chrono = { workspace = true } diesel = { version = "2.2.2", features = [ @@ -36,8 +36,6 @@ diesel = { version = "2.2.2", features = [ ] } diesel_migrations = { version = "2.2.0", features = ["sqlite"] } ed25519-dalek = "2.1.1" -ethers-core.workspace = true -ethers.workspace = true futures.workspace = true hex.workspace = true libsqlite3-sys = { version = "0.29.0", optional = true } @@ -53,7 +51,6 @@ reqwest = { version = "0.12.4", features = ["stream"] } serde = { workspace = true } serde_json.workspace = true sha2.workspace = true -smart-default = "0.7.1" thiserror = { workspace = true } tls_codec = { workspace = true } tokio = { workspace = true, features = [ @@ -83,7 +80,6 @@ anyhow.workspace = true async-barrier = "1.1" criterion = { version = "0.5", features = ["html_reports", "async_tokio"] } ctor.workspace = true -flume = "0.11" mockall = "0.13.0" mockito = "1.4.0" tempfile = "3.5.0" @@ -93,6 +89,7 @@ tracing-test = "0.2.4" tracing.workspace = true xmtp_api_grpc = { path = "../xmtp_api_grpc" } xmtp_id = { path = "../xmtp_id", features = ["test-utils"] } +ethers.workspace = true [[bench]] harness = false diff --git a/xmtp_mls/benches/group_limit.rs b/xmtp_mls/benches/group_limit.rs index 8408bd305..6d41676c2 100755 --- a/xmtp_mls/benches/group_limit.rs +++ b/xmtp_mls/benches/group_limit.rs @@ -3,11 +3,9 @@ //! using `RUST_LOG=trace` will additionally output a `tracing.folded` file, which //! may be used to generate a flamegraph of execution from tracing logs. use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; -use ethers::signers::LocalWallet; use std::{collections::HashMap, sync::Arc}; use tokio::runtime::{Builder, Handle, Runtime}; use tracing::{trace_span, Instrument}; -use xmtp_cryptography::utils::rng; use xmtp_mls::{ builder::ClientBuilder, groups::GroupMetadataOptions, @@ -33,7 +31,7 @@ fn setup() -> (Arc, Vec, Runtime) { .unwrap(); let (client, identities) = runtime.block_on(async { - let wallet = LocalWallet::new(&mut rng()); + let wallet = xmtp_cryptography::utils::generate_local_wallet(); // use dev network if `DEV_GRPC` is set let dev = std::env::var("DEV_GRPC"); diff --git a/xmtp_mls/src/api/mls.rs b/xmtp_mls/src/api/mls.rs index 7ac073f26..dd79a1a76 100644 --- a/xmtp_mls/src/api/mls.rs +++ b/xmtp_mls/src/api/mls.rs @@ -2,9 +2,7 @@ use std::collections::HashMap; use super::ApiClientWrapper; use crate::{retry_async, XmtpApi}; -use xmtp_proto::api_client::{ - Error as ApiError, ErrorKind, GroupMessageStream, WelcomeMessageStream, -}; +use xmtp_proto::api_client::{Error as ApiError, ErrorKind}; use xmtp_proto::xmtp::mls::api::v1::{ group_message_input::{Version as GroupMessageInputVersion, V1 as GroupMessageInputV1}, subscribe_group_messages_request::Filter as GroupFilterProto, @@ -264,7 +262,7 @@ where pub async fn subscribe_group_messages( &self, filters: Vec, - ) -> Result { + ) -> Result> + '_, ApiError> { self.api_client .subscribe_group_messages(SubscribeGroupMessagesRequest { filters: filters.into_iter().map(|f| f.into()).collect(), @@ -276,7 +274,7 @@ where &self, installation_key: Vec, id_cursor: Option, - ) -> Result { + ) -> Result> + '_, ApiError> { self.api_client .subscribe_welcome_messages(SubscribeWelcomeMessagesRequest { filters: vec![WelcomeFilterProto { diff --git a/xmtp_mls/src/api/test_utils.rs b/xmtp_mls/src/api/test_utils.rs index 973ab17b4..604804aa5 100644 --- a/xmtp_mls/src/api/test_utils.rs +++ b/xmtp_mls/src/api/test_utils.rs @@ -1,24 +1,26 @@ -use async_trait::async_trait; use mockall::mock; use xmtp_proto::{ - api_client::{ - ClientWithMetadata, Error, GroupMessageStream, WelcomeMessageStream, XmtpIdentityClient, - XmtpMlsClient, - }, - xmtp::identity::api::v1::{ - GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request, - GetIdentityUpdatesResponse as GetIdentityUpdatesV2Response, GetInboxIdsRequest, - GetInboxIdsResponse, PublishIdentityUpdateRequest, PublishIdentityUpdateResponse, - }, - xmtp::mls::api::v1::{ - group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, - FetchKeyPackagesRequest, FetchKeyPackagesResponse, GroupMessage, QueryGroupMessagesRequest, - QueryGroupMessagesResponse, QueryWelcomeMessagesRequest, QueryWelcomeMessagesResponse, - SendGroupMessagesRequest, SendWelcomeMessagesRequest, SubscribeGroupMessagesRequest, - SubscribeWelcomeMessagesRequest, UploadKeyPackageRequest, + api_client::{ClientWithMetadata, Error, XmtpIdentityClient, XmtpMlsClient, XmtpMlsStreams}, + xmtp::{ + identity::api::v1::{ + GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request, + GetIdentityUpdatesResponse as GetIdentityUpdatesV2Response, GetInboxIdsRequest, + GetInboxIdsResponse, PublishIdentityUpdateRequest, PublishIdentityUpdateResponse, + }, + mls::api::v1::{ + group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, + FetchKeyPackagesRequest, FetchKeyPackagesResponse, GroupMessage, + QueryGroupMessagesRequest, QueryGroupMessagesResponse, QueryWelcomeMessagesRequest, + QueryWelcomeMessagesResponse, SendGroupMessagesRequest, SendWelcomeMessagesRequest, + SubscribeGroupMessagesRequest, SubscribeWelcomeMessagesRequest, + UploadKeyPackageRequest, + }, }, }; +#[cfg(feature = "http-api")] +use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage; + use crate::XmtpTestClient; pub fn build_group_messages(num_messages: usize, group_id: Vec) -> Vec { @@ -46,7 +48,6 @@ mock! { fn set_app_version(&mut self, version: String) -> Result<(), Error>; } - #[async_trait] impl XmtpMlsClient for ApiClient { async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error>; async fn fetch_key_packages( @@ -57,18 +58,30 @@ mock! { async fn send_welcome_messages(&self, request: SendWelcomeMessagesRequest) -> Result<(), Error>; async fn query_group_messages(&self, request: QueryGroupMessagesRequest) -> Result; async fn query_welcome_messages(&self, request: QueryWelcomeMessagesRequest) -> Result; - async fn subscribe_group_messages(&self, request: SubscribeGroupMessagesRequest) -> Result; - async fn subscribe_welcome_messages(&self, request: SubscribeWelcomeMessagesRequest) -> Result; } - #[async_trait] + impl XmtpMlsStreams for ApiClient { + #[cfg(not(feature = "http-api"))] + type GroupMessageStream<'a> = xmtp_api_grpc::GroupMessageStream; + #[cfg(not(feature = "http-api"))] + type WelcomeMessageStream<'a> = xmtp_api_grpc::WelcomeMessageStream; + + #[cfg(feature = "http-api")] + type GroupMessageStream<'a> = futures::stream::BoxStream<'static, Result>; + #[cfg(feature = "http-api")] + type WelcomeMessageStream<'a> = futures::stream::BoxStream<'static, Result>; + + + async fn subscribe_group_messages(&self, request: SubscribeGroupMessagesRequest) -> Result<::GroupMessageStream<'static>, Error>; + async fn subscribe_welcome_messages(&self, request: SubscribeWelcomeMessagesRequest) -> Result<::WelcomeMessageStream<'static>, Error>; + } + impl XmtpIdentityClient for ApiClient { async fn publish_identity_update(&self, request: PublishIdentityUpdateRequest) -> Result; async fn get_identity_updates_v2(&self, request: GetIdentityUpdatesV2Request) -> Result; async fn get_inbox_ids(&self, request: GetInboxIdsRequest) -> Result; } - #[async_trait] impl XmtpTestClient for ApiClient { async fn create_local() -> Self { ApiClient } async fn create_dev() -> Self { ApiClient } diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index dce196d84..cf1353cf3 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -142,13 +142,10 @@ mod tests { api::test_utils::*, identity::Identity, storage::identity::StoredIdentity, utils::test::rand_vec, Store, }; - use ethers::signers::Signer; - use ethers_core::k256; use openmls::credentials::{Credential, CredentialType}; use openmls_basic_credential::SignatureKeyPair; use openmls_traits::types::SignatureScheme; use prost::Message; - use xmtp_cryptography::signature::h160addr_to_string; use xmtp_cryptography::utils::{generate_local_wallet, rng}; use xmtp_id::associations::ValidatedLegacySignedPublicKey; use xmtp_id::associations::{ @@ -188,10 +185,10 @@ mod tests { /// Generate a random legacy key proto bytes and corresponding account address. async fn generate_random_legacy_key() -> (Vec, String) { let wallet = generate_local_wallet(); - let address = h160addr_to_string(wallet.address()); + let address = wallet.get_address(); let created_ns = rand_u64(); - let secret_key = k256::ecdsa::SigningKey::random(&mut rng()); - let public_key = k256::ecdsa::VerifyingKey::from(&secret_key); + let secret_key = ethers::core::k256::ecdsa::SigningKey::random(&mut rng()); + let public_key = ethers::core::k256::ecdsa::VerifyingKey::from(&secret_key); let public_key_bytes = public_key.to_sec1_bytes().to_vec(); let mut public_key_buf = vec![]; UnsignedPublicKey { @@ -205,7 +202,7 @@ mod tests { .encode(&mut public_key_buf) .unwrap(); let message = ValidatedLegacySignedPublicKey::text(&public_key_buf); - let signed_public_key = wallet.sign_message(message).await.unwrap().to_vec(); + let signed_public_key: Vec = wallet.sign(&message).unwrap().into(); let (bytes, recovery_id) = signed_public_key.as_slice().split_at(64); let recovery_id = recovery_id[0]; let signed_private_key: SignedPrivateKey = SignedPrivateKey { diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index cb8894271..05fe149f7 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::pin::Pin; use std::sync::Arc; use futures::Stream; @@ -18,7 +17,7 @@ impl MlsGroup { pub(crate) async fn process_stream_entry( &self, envelope: GroupMessage, - client: Arc>, + client: &Client, ) -> Result, GroupError> where ApiClient: XmtpApi, @@ -34,11 +33,9 @@ impl MlsGroup { let created_ns = msgv1.created_ns; if !self.has_already_synced(msg_id).await? { - let client_pointer = client.clone(); let process_result = retry_async!( Retry::default(), (async { - let client_pointer = client_pointer.clone(); let client_id = client_id.clone(); let msgv1 = msgv1.clone(); self.context @@ -55,7 +52,7 @@ impl MlsGroup { ); self.process_message( - client_pointer.as_ref(), + client, &mut openmls_group, &provider, &msgv1, @@ -71,7 +68,7 @@ impl MlsGroup { if let Some(GroupError::ReceiveError(_)) = process_result.as_ref().err() { // Swallow errors here, since another process may have successfully saved the message // to the DB - match self.sync_with_conn(&client.mls_provider()?, &client).await { + match self.sync_with_conn(&client.mls_provider()?, client).await { Ok(_) => { log::debug!("Sync triggered by streamed message successful") } @@ -109,7 +106,7 @@ impl MlsGroup { pub async fn process_streamed_group_message( &self, envelope_bytes: Vec, - client: Arc>, + client: &Client, ) -> Result where ApiClient: XmtpApi, @@ -121,21 +118,21 @@ impl MlsGroup { message.ok_or(GroupError::MissingMessage) } - pub async fn stream( - &self, - client: Arc>, - ) -> Result + Send + '_>>, GroupError> + pub async fn stream<'a, ApiClient>( + &'a self, + client: &'a Client, + ) -> Result + '_, GroupError> where - ApiClient: crate::XmtpApi, + ApiClient: crate::XmtpApi + 'static, { Ok(client - .stream_messages(HashMap::from([( + .stream_messages(Arc::new(HashMap::from([( self.group_id.clone(), MessagesStreamInfo { convo_created_at_ns: self.created_at_ns, cursor: 0, }, - )])) + )]))) .await?) } @@ -146,7 +143,7 @@ impl MlsGroup { callback: impl FnMut(StoredGroupMessage) + Send + 'static, ) -> StreamHandle> where - ApiClient: crate::XmtpApi, + ApiClient: crate::XmtpApi + 'static, { Client::::stream_messages_with_callback( client, @@ -202,7 +199,7 @@ mod tests { let mut message_bytes: Vec = Vec::new(); message.encode(&mut message_bytes).unwrap(); let message_again = amal_group - .process_streamed_group_message(message_bytes, Arc::new(amal)) + .process_streamed_group_message(message_bytes, &amal) .await; if let Ok(message) = message_again { @@ -237,7 +234,8 @@ mod tests { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let mut stream = UnboundedReceiverStream::new(rx); tokio::spawn(async move { - let mut stream = bola_group_ptr.stream(bola_ptr).await.unwrap(); + let stream = bola_group_ptr.stream(&bola_ptr).await.unwrap(); + futures::pin_mut!(stream); while let Some(item) = stream.next().await { let _ = tx.send(item); notify_ptr.notify_one(); @@ -281,7 +279,8 @@ mod tests { let amal_ptr = amal.clone(); let group_ptr = group.clone(); tokio::spawn(async move { - let mut stream = group_ptr.stream(amal_ptr).await.unwrap(); + let stream = group_ptr.stream(&amal_ptr).await.unwrap(); + futures::pin_mut!(stream); while let Some(item) = stream.next().await { let _ = tx.send(item); } @@ -323,7 +322,8 @@ mod tests { let (start_tx, start_rx) = tokio::sync::oneshot::channel(); let mut stream = UnboundedReceiverStream::new(rx); tokio::spawn(async move { - let mut stream = amal_group_ptr.stream(amal_ptr).await.unwrap(); + let stream = amal_group_ptr.stream(&amal_ptr).await.unwrap(); + futures::pin_mut!(stream); let _ = start_tx.send(()); while let Some(item) = stream.next().await { let _ = tx.send(item); diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index 27eadbfd2..bbbf77aee 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -15,7 +15,6 @@ use crate::{ }; use crate::{retryable, Fetch, Store}; use ed25519_dalek::SigningKey; -use ethers::signers::WalletError; use log::debug; use log::info; use openmls::prelude::tls_codec::Serialize; @@ -141,8 +140,6 @@ pub enum IdentityError { #[error("legacy key does not match address")] LegacyKeyMismatch, #[error(transparent)] - WalletError(#[from] WalletError), - #[error(transparent)] OpenMls(#[from] openmls::prelude::Error), #[error(transparent)] StorageError(#[from] crate::storage::StorageError), diff --git a/xmtp_mls/src/identity_updates.rs b/xmtp_mls/src/identity_updates.rs index b4350eb53..e9860755e 100644 --- a/xmtp_mls/src/identity_updates.rs +++ b/xmtp_mls/src/identity_updates.rs @@ -445,7 +445,6 @@ pub async fn load_identity_updates( #[cfg(test)] pub(crate) mod tests { - use ethers::signers::LocalWallet; use tracing_test::traced_test; use xmtp_cryptography::utils::generate_local_wallet; use xmtp_id::{ @@ -465,7 +464,7 @@ pub(crate) mod tests { use super::load_identity_updates; pub(crate) async fn sign_with_wallet( - wallet: &LocalWallet, + wallet: &impl InboxOwner, signature_request: &mut SignatureRequest, ) { let wallet_signature: Vec = wallet diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 9d330f0bc..83afcb6bb 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -11,7 +11,6 @@ mod hpke; pub mod identity; mod identity_updates; mod mutex_registry; -pub mod owner; pub mod retry; pub mod storage; pub mod subscriptions; @@ -22,46 +21,83 @@ mod xmtp_openmls_provider; pub use client::{Client, Network}; use storage::StorageError; -use xmtp_cryptography::signature::{RecoverableSignature, SignatureError}; -use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient}; + +pub use trait_impls::*; /// XMTP Api Super Trait /// Implements all Trait Network APIs for convenience. -#[cfg(not(test))] -pub trait XmtpApi -where - Self: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata, -{ -} -#[cfg(not(test))] -impl XmtpApi for T where T: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata + ?Sized {} +mod trait_impls { + pub use inner::*; + + // native, release + #[cfg(not(test))] + mod inner { + use xmtp_proto::api_client::{ + ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient, XmtpMlsStreams, + }; -#[cfg(test)] -pub trait XmtpApi -where - Self: XmtpMlsClient + XmtpIdentityClient + XmtpTestClient + ClientWithMetadata, -{ -} + pub trait XmtpApi + where + Self: XmtpMlsClient + + XmtpMlsStreams + + XmtpIdentityClient + + ClientWithMetadata + + Send + + Sync, + { + } + impl XmtpApi for T where + T: XmtpMlsClient + + XmtpMlsStreams + + XmtpIdentityClient + + ClientWithMetadata + + Send + + Sync + + ?Sized + { + } + } -#[cfg(test)] -impl XmtpApi for T where - T: XmtpMlsClient + XmtpIdentityClient + XmtpTestClient + ClientWithMetadata + ?Sized -{ + // test, native + #[cfg(test)] + mod inner { + use xmtp_proto::api_client::{ + ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient, XmtpMlsStreams, + }; + + pub trait XmtpApi + where + Self: XmtpMlsClient + + XmtpMlsStreams + + XmtpIdentityClient + + crate::XmtpTestClient + + ClientWithMetadata + + Send + + Sync, + { + } + impl XmtpApi for T where + T: XmtpMlsClient + + XmtpMlsStreams + + XmtpIdentityClient + + crate::XmtpTestClient + + ClientWithMetadata + + Send + + Sync + + ?Sized + { + } + } } #[cfg(any(test, feature = "test-utils", feature = "bench"))] -#[async_trait::async_trait] -pub trait XmtpTestClient { +#[trait_variant::make(XmtpTestClient: Send)] +pub trait LocalXmtpTestClient { async fn create_local() -> Self; async fn create_dev() -> Self; } -pub trait InboxOwner { - /// Get address of the wallet. - fn get_address(&self) -> String; - /// Sign text with the wallet. - fn sign(&self, text: &str) -> Result; -} +pub use xmtp_id::InboxOwner; /// Inserts a model to the underlying data store, erroring if it already exists pub trait Store { diff --git a/xmtp_mls/src/owner/evm_owner.rs b/xmtp_mls/src/owner/evm_owner.rs deleted file mode 100644 index bbf4fbf19..000000000 --- a/xmtp_mls/src/owner/evm_owner.rs +++ /dev/null @@ -1,16 +0,0 @@ -pub use ethers::signers::{LocalWallet, Signer}; - -use xmtp_cryptography::signature::{h160addr_to_string, RecoverableSignature, SignatureError}; - -use crate::InboxOwner; - -impl InboxOwner for LocalWallet { - fn get_address(&self) -> String { - h160addr_to_string(self.address()) - } - - fn sign(&self, text: &str) -> Result { - let message_hash = ethers_core::utils::hash_message(text); - Ok(self.sign_hash(message_hash)?.to_vec().into()) - } -} diff --git a/xmtp_mls/src/owner/mod.rs b/xmtp_mls/src/owner/mod.rs deleted file mode 100644 index 1b15ea292..000000000 --- a/xmtp_mls/src/owner/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod evm_owner; diff --git a/xmtp_mls/src/retry.rs b/xmtp_mls/src/retry.rs index 6bab96703..8d844b4f4 100644 --- a/xmtp_mls/src/retry.rs +++ b/xmtp_mls/src/retry.rs @@ -19,7 +19,6 @@ use std::time::Duration; use rand::Rng; -use smart_default::SmartDefault; /// Specifies which errors are retryable. /// All Errors are not retryable by-default. @@ -28,19 +27,26 @@ pub trait RetryableError: std::error::Error { } /// Options to specify how to retry a function -#[derive(SmartDefault, Debug, PartialEq, Eq, Copy, Clone)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] pub struct Retry { - #[default = 5] retries: usize, - #[default(_code = "std::time::Duration::from_millis(50)")] duration: std::time::Duration, - #[default = 3] // The amount to multiply the duration on each subsequent attempt multiplier: u32, - #[default = 25] max_jitter_ms: usize, } +impl Default for Retry { + fn default() -> Self { + Self { + retries: 5, + duration: std::time::Duration::from_millis(50), + multiplier: 3, + max_jitter_ms: 25, + } + } +} + impl Retry { /// Get the number of retries this is configured with. pub fn retries(&self) -> usize { @@ -119,7 +125,7 @@ impl Retry { /// ``` /// use xmtp_mls::{retry_async, retry::{RetryableError, Retry}}; /// use thiserror::Error; -/// use flume::bounded; +/// use tokio::sync::mpsc; /// /// #[derive(Debug, Error)] /// enum MyError { @@ -138,8 +144,8 @@ impl Retry { /// } /// } /// -/// async fn fallable_fn(rx: &flume::Receiver) -> Result<(), MyError> { -/// if rx.recv_async().await.unwrap() == 2 { +/// async fn fallable_fn(rx: &mut mpsc::Receiver) -> Result<(), MyError> { +/// if rx.recv().await.unwrap() == 2 { /// return Ok(()); /// } /// Err(MyError::Retryable) @@ -147,14 +153,14 @@ impl Retry { /// /// #[tokio::main] /// async fn main() -> Result<(), MyError> { -/// -/// let (tx, rx) = flume::bounded(3); +/// +/// let (tx, mut rx) = mpsc::channel(3); /// /// for i in 0..3 { -/// tx.send(i).unwrap(); +/// tx.send(i).await.unwrap(); /// } /// retry_async!(Retry::default(), (async { -/// fallable_fn(&rx.clone()).await +/// fallable_fn(&mut rx).await /// })) /// } /// ``` @@ -211,6 +217,7 @@ impl RetryableError for xmtp_proto::api_client::Error { mod tests { use super::*; use thiserror::Error; + use tokio::sync::mpsc; #[derive(Debug, Error)] enum SomeError { @@ -291,8 +298,8 @@ mod tests { #[tokio::test] async fn it_works_async() { - async fn retryable_async_fn(rx: &flume::Receiver) -> Result<(), SomeError> { - let val = rx.recv_async().await.unwrap(); + async fn retryable_async_fn(rx: &mut mpsc::Receiver) -> Result<(), SomeError> { + let val = rx.recv().await.unwrap(); if val == 2 { return Ok(()); } @@ -301,14 +308,14 @@ mod tests { Err(SomeError::ARetryableError) } - let (tx, rx) = flume::bounded(3); + let (tx, mut rx) = mpsc::channel(3); for i in 0..3 { - tx.send(i).unwrap(); + tx.send(i).await.unwrap(); } retry_async!( Retry::default(), - (async { retryable_async_fn(&rx.clone()).await }) + (async { retryable_async_fn(&mut rx).await }) ) .unwrap(); assert!(rx.is_empty()); diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 6c3fd03e1..13cfc9657 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, pin::Pin, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use futures::{FutureExt, Stream, StreamExt}; use prost::Message; @@ -67,7 +67,7 @@ impl From for (Vec, MessagesStreamInfo) { impl Client where - ApiClient: XmtpApi, + ApiClient: XmtpApi + 'static, { async fn process_streamed_welcome( &self, @@ -130,7 +130,7 @@ where pub async fn stream_conversations( &self, - ) -> Result + Send + '_>>, ClientError> { + ) -> Result + '_, ClientError> { let event_queue = tokio_stream::wrappers::BroadcastStream::new(self.local_events.subscribe()); @@ -167,48 +167,50 @@ where } }); - Ok(Box::pin(futures::stream::select(stream, event_queue))) + Ok(futures::stream::select(stream, event_queue)) } #[tracing::instrument(skip(self, group_id_to_info))] pub(crate) async fn stream_messages( - self: Arc, - group_id_to_info: HashMap, MessagesStreamInfo>, - ) -> Result + Send>>, ClientError> { + &self, + group_id_to_info: Arc, MessagesStreamInfo>>, + ) -> Result + '_, ClientError> + where + ApiClient: 'static, + { let filters: Vec = group_id_to_info .iter() .map(|(group_id, info)| GroupFilter::new(group_id.clone(), Some(info.cursor))) .collect(); + let messages_subscription = self.api_client.subscribe_group_messages(filters).await?; let stream = messages_subscription .map(move |res| { - let context = self.context.clone(); - let client = self.clone(); - - let group_id_to_info = group_id_to_info.clone(); + let group_info = group_id_to_info.clone(); async move { match res { Ok(envelope) => { log::info!("Received message streaming payload"); let group_id = extract_group_id(&envelope)?; log::info!("Extracted group id {}", hex::encode(&group_id)); - let stream_info = group_id_to_info.get(&group_id).ok_or( + let stream_info = group_info.get(&group_id).ok_or( ClientError::StreamInconsistency( "Received message for a non-subscribed group".to_string(), ), )?; - let mls_group = - MlsGroup::new(context, group_id, stream_info.convo_created_at_ns); - mls_group - .process_stream_entry(envelope.clone(), client.clone()) - .await + let mls_group = MlsGroup::new( + self.context.clone(), + group_id, + stream_info.convo_created_at_ns, + ); + mls_group.process_stream_entry(envelope, self).await } Err(err) => Err(GroupError::Api(err)), } } }) - .filter_map(move |res| async { + .filter_map(|res| async { match res.await { Ok(Some(message)) => Some(message), Ok(None) => { @@ -221,14 +223,13 @@ where } } }); - - Ok(Box::pin(stream)) + Ok(stream) } } impl Client where - ApiClient: XmtpApi, + ApiClient: XmtpApi + 'static, { pub fn stream_conversations_with_callback( client: Arc>, @@ -237,7 +238,8 @@ where let (tx, rx) = oneshot::channel(); let handle = tokio::spawn(async move { - let mut stream = client.stream_conversations().await?; + let stream = client.stream_conversations().await?; + futures::pin_mut!(stream); let _ = tx.send(()); while let Some(convo) = stream.next().await { convo_callback(convo) @@ -258,9 +260,11 @@ where ) -> StreamHandle> { let (tx, rx) = oneshot::channel(); + let client = client.clone(); let handle = tokio::spawn(async move { - let mut stream = Self::stream_messages(client, group_id_to_info).await?; + let stream = Self::stream_messages(&client, group_id_to_info.into()).await?; let _ = tx.send(()); + futures::pin_mut!(stream); while let Some(message) = stream.next().await { callback(message) } @@ -274,11 +278,11 @@ where } pub async fn stream_all_messages( - client: Arc>, - ) -> Result>, ClientError> { - client.sync_welcomes().await?; + &self, + ) -> Result> + '_, ClientError> { + self.sync_welcomes().await?; - let mut group_id_to_info = client + let mut group_id_to_info = self .store() .conn()? .find_groups(None, None, None, None)? @@ -287,12 +291,14 @@ where .collect::, MessagesStreamInfo>>(); let stream = async_stream::stream! { - let client = client.clone(); - let mut messages_stream = client - .clone() - .stream_messages(group_id_to_info.clone()) + let messages_stream = self + .stream_messages(Arc::new(group_id_to_info.clone())) .await?; - let mut convo_stream = Self::stream_conversations(&client).await?; + futures::pin_mut!(messages_stream); + + let convo_stream = self.stream_conversations().await?; + futures::pin_mut!(convo_stream); + let mut extra_messages = Vec::new(); loop { @@ -316,7 +322,6 @@ where if group_id_to_info.contains_key(&new_group.group_id) { continue; } - for info in group_id_to_info.values_mut() { info.cursor = 0; } @@ -327,8 +332,7 @@ where cursor: 1, // For the new group, stream all messages since the group was created }, ); - - let new_messages_stream = match client.clone().stream_messages(group_id_to_info.clone()).await { + let new_messages_stream = match self.stream_messages(Arc::new(group_id_to_info.clone())).await { Ok(stream) => stream, Err(e) => { log::error!("{}", e); @@ -341,13 +345,13 @@ where while let Some(Some(message)) = messages_stream.next().now_or_never() { extra_messages.push(message); } - let _ = std::mem::replace(&mut messages_stream, new_messages_stream); + messages_stream.set(new_messages_stream); }, } } }; - Ok(Box::pin(stream)) + Ok(stream) } pub fn stream_all_messages_with_callback( @@ -357,8 +361,9 @@ where let (tx, rx) = oneshot::channel(); let handle = tokio::spawn(async move { - let mut stream = Self::stream_all_messages(client).await?; + let stream = Self::stream_all_messages(&client).await?; let _ = tx.send(()); + futures::pin_mut!(stream); while let Some(message) = stream.next().await { match message { Ok(m) => callback(m), @@ -408,7 +413,8 @@ mod tests { let mut stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); let bob_ptr = bob.clone(); tokio::spawn(async move { - let mut bob_stream = bob_ptr.stream_conversations().await.unwrap(); + let bob_stream = bob_ptr.stream_conversations().await.unwrap(); + futures::pin_mut!(bob_stream); while let Some(item) = bob_stream.next().await { let _ = tx.send(item); } @@ -445,7 +451,8 @@ mod tests { let notify_ptr = notify.clone(); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); tokio::spawn(async move { - let mut stream = alice_group.stream(alice).await.unwrap(); + let stream = alice_group.stream(&alice).await.unwrap(); + futures::pin_mut!(stream); while let Some(item) = stream.next().await { let _ = tx.send(item); notify_ptr.notify_one(); diff --git a/xmtp_mls/src/utils/bench.rs b/xmtp_mls/src/utils/bench.rs index 61ac2d590..5d89e405d 100644 --- a/xmtp_mls/src/utils/bench.rs +++ b/xmtp_mls/src/utils/bench.rs @@ -4,7 +4,6 @@ #![allow(clippy::unwrap_used)] use crate::{builder::ClientBuilder, Client}; -use ethers::signers::{LocalWallet, Signer}; use indicatif::{ProgressBar, ProgressStyle}; use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; @@ -18,7 +17,8 @@ use tracing_subscriber::{ util::SubscriberInitExt, EnvFilter, }; -use xmtp_cryptography::utils::rng; +use xmtp_cryptography::utils::generate_local_wallet; +use xmtp_id::InboxOwner; use super::test::TestClient; @@ -126,13 +126,13 @@ impl Identity { } async fn create_identity(is_dev_network: bool) -> Identity { - let wallet = LocalWallet::new(&mut rng()); + let wallet = generate_local_wallet(); let client = if is_dev_network { ClientBuilder::new_dev_client(&wallet).await } else { ClientBuilder::new_test_client(&wallet).await }; - Identity::new(client.inbox_id(), format!("0x{:x}", wallet.address())) + Identity::new(client.inbox_id(), wallet.get_address()) } async fn create_identities(n: usize, is_dev_network: bool) -> Vec { diff --git a/xmtp_mls/src/utils/test.rs b/xmtp_mls/src/utils/test.rs index 5fc724d5e..2ca1378d8 100755 --- a/xmtp_mls/src/utils/test.rs +++ b/xmtp_mls/src/utils/test.rs @@ -49,7 +49,6 @@ pub fn rand_time() -> i64 { rng.gen_range(0..1_000_000_000) } -#[async_trait::async_trait] #[cfg(feature = "http-api")] impl XmtpTestClient for XmtpHttpApiClient { async fn create_local() -> Self { @@ -61,7 +60,6 @@ impl XmtpTestClient for XmtpHttpApiClient { } } -#[async_trait::async_trait] impl XmtpTestClient for GrpcClient { async fn create_local() -> Self { GrpcClient::create("http://localhost:5556".into(), false) diff --git a/xmtp_proto/Cargo.toml b/xmtp_proto/Cargo.toml index fa6da32e0..6217214fc 100644 --- a/xmtp_proto/Cargo.toml +++ b/xmtp_proto/Cargo.toml @@ -4,7 +4,6 @@ name = "xmtp_proto" version.workspace = true [dependencies] -async-trait = { workspace = true } futures = { workspace = true } futures-core = { workspace = true } pbjson-types.workspace = true @@ -15,6 +14,7 @@ prost-types = { workspace = true } serde = { workspace = true } openmls_basic_credential = { workspace = true, optional = true } openmls = { workspace = true, optional = true } +trait-variant = "0.1.2" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tonic = { workspace = true } @@ -39,4 +39,4 @@ proto_full = ["xmtp-identity","xmtp-identity-api-v1","xmtp-identity-associations "xmtp-mls-message_contents" = [] "xmtp-mls_validation-v1" = ["xmtp-identity-associations"] "xmtp-xmtpv4" = ["xmtp-identity-associations","xmtp-mls-api-v1"] -## @@protoc_insertion_point(features) \ No newline at end of file +## @@protoc_insertion_point(features) diff --git a/xmtp_proto/src/api_client.rs b/xmtp_proto/src/api_client.rs index cd0dd6160..60db8e4e9 100644 --- a/xmtp_proto/src/api_client.rs +++ b/xmtp_proto/src/api_client.rs @@ -1,7 +1,6 @@ use std::{error::Error as StdError, fmt}; -use async_trait::async_trait; -use futures::{stream, Stream}; +use futures::Stream; pub use super::xmtp::message_api::v1::{ BatchQueryRequest, BatchQueryResponse, Envelope, PagingInfo, PublishRequest, PublishResponse, @@ -103,22 +102,28 @@ pub trait XmtpApiSubscription { fn close_stream(&mut self); } -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +#[allow(async_fn_in_trait)] pub trait MutableApiSubscription: Stream> + Send { async fn update(&mut self, req: SubscribeRequest) -> Result<(), Error>; fn close(&self); } -pub trait ClientWithMetadata: Send + Sync { +pub trait ClientWithMetadata { fn set_libxmtp_version(&mut self, version: String) -> Result<(), Error>; fn set_app_version(&mut self, version: String) -> Result<(), Error>; } +/// Global Marker trait for WebAssembly +#[cfg(target_arch = "wasm32")] +pub trait Wasm {} +#[cfg(target_arch = "wasm32")] +impl Wasm for T {} + // Wasm futures don't have `Send` or `Sync` bounds. -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -pub trait XmtpApiClient: Send + Sync { +#[allow(async_fn_in_trait)] +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(XmtpApiClient: Send))] +#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpApiClient: Wasm))] +pub trait LocalXmtpApiClient { type Subscription: XmtpApiSubscription; type MutableSubscription: MutableApiSubscription; @@ -140,20 +145,11 @@ pub trait XmtpApiClient: Send + Sync { async fn batch_query(&self, request: BatchQueryRequest) -> Result; } -#[cfg(not(target_arch = "wasm32"))] -pub type GroupMessageStream = stream::BoxStream<'static, Result>; -#[cfg(target_arch = "wasm32")] -pub type GroupMessageStream = stream::LocalBoxStream<'static, Result>; - -#[cfg(not(target_arch = "wasm32"))] -pub type WelcomeMessageStream = stream::BoxStream<'static, Result>; -#[cfg(target_arch = "wasm32")] -pub type WelcomeMessageStream = stream::LocalBoxStream<'static, Result>; - // Wasm futures don't have `Send` or `Sync` bounds. -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -pub trait XmtpMlsClient: Send + Sync + 'static { +#[allow(async_fn_in_trait)] +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(XmtpMlsClient: Send))] +#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpMlsClient: Wasm))] +pub trait LocalXmtpMlsClient { async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error>; async fn fetch_key_packages( &self, @@ -170,20 +166,57 @@ pub trait XmtpMlsClient: Send + Sync + 'static { &self, request: QueryWelcomeMessagesRequest, ) -> Result; +} + +#[allow(async_fn_in_trait)] +#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpMlsStreams: Wasm))] +pub trait LocalXmtpMlsStreams { + type GroupMessageStream<'a>: Stream> + 'a + where + Self: 'a; + + type WelcomeMessageStream<'a>: Stream> + 'a + where + Self: 'a; + async fn subscribe_group_messages( &self, request: SubscribeGroupMessagesRequest, - ) -> Result; + ) -> Result, Error>; async fn subscribe_welcome_messages( &self, request: SubscribeWelcomeMessagesRequest, - ) -> Result; + ) -> Result, Error>; } -// Wasm futures don't have `Send` or `Sync` bounds. -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -pub trait XmtpIdentityClient: Send + Sync + 'static { +// we manually make a Local+Non-Local trait variant here b/c the +// macro breaks with GATs +#[allow(async_fn_in_trait)] +#[cfg(not(target_arch = "wasm32"))] +pub trait XmtpMlsStreams: Send { + type GroupMessageStream<'a>: Stream> + Send + 'a + where + Self: 'a; + + type WelcomeMessageStream<'a>: Stream> + Send + 'a + where + Self: 'a; + + fn subscribe_group_messages( + &self, + request: SubscribeGroupMessagesRequest, + ) -> impl futures::Future, Error>> + Send; + + fn subscribe_welcome_messages( + &self, + request: SubscribeWelcomeMessagesRequest, + ) -> impl futures::Future, Error>> + Send; +} + +#[allow(async_fn_in_trait)] +#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(XmtpIdentityClient: Send))] +#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpIdentityClient: Wasm))] +pub trait LocalXmtpIdentityClient { async fn publish_identity_update( &self, request: PublishIdentityUpdateRequest,