diff --git a/core/src/client/async_client/mod.rs b/core/src/client/async_client/mod.rs index 9a6e6e59f4..9e39fb8053 100644 --- a/core/src/client/async_client/mod.rs +++ b/core/src/client/async_client/mod.rs @@ -33,8 +33,8 @@ mod utils; use crate::client::async_client::helpers::{process_subscription_close_response, InnerBatchResponse}; use crate::client::async_client::utils::MaybePendingFutures; use crate::client::{ - BatchMessage, BatchResponse, ClientT, ReceivedMessage, RegisterNotificationMessage, RequestMessage, - Subscription, SubscriptionClientT, SubscriptionKind, SubscriptionMessage, TransportReceiverT, TransportSenderT, Error + BatchMessage, BatchResponse, ClientT, Error, ReceivedMessage, RegisterNotificationMessage, RequestMessage, + Subscription, SubscriptionClientT, SubscriptionKind, SubscriptionMessage, TransportReceiverT, TransportSenderT, }; use crate::error::RegisterMethodError; use crate::params::{BatchRequestBuilder, EmptyBatchRequest}; @@ -64,7 +64,7 @@ use serde::de::DeserializeOwned; use tokio::sync::{mpsc, oneshot}; use tracing::instrument; -use self::utils::{IntervalStream, InactivityCheck}; +use self::utils::{InactivityCheck, IntervalStream}; use super::{generate_batch_id_range, FrontToBack, IdKind, RequestIdManager}; @@ -94,11 +94,7 @@ pub struct PingConfig { impl Default for PingConfig { fn default() -> Self { - Self { - ping_interval: Duration::from_secs(30), - max_failures: 1, - inactive_limit: Duration::from_secs(40), - } + Self { ping_interval: Duration::from_secs(30), max_failures: 1, inactive_limit: Duration::from_secs(40) } } } @@ -126,9 +122,9 @@ impl PingConfig { /// Configure how many times the connection is allowed be /// inactive until the connection is closed. - /// + /// /// # Panics - /// + /// /// This method panics if `max` == 0. pub fn max_failures(mut self, max: usize) -> Self { assert!(max > 0); @@ -137,7 +133,6 @@ impl PingConfig { } } - #[derive(Debug, Default, Clone)] pub(crate) struct ThreadSafeRequestManager(Arc>); @@ -179,7 +174,9 @@ impl ErrorFromBack { // This should never happen because the receiving end is still alive. // Before shutting down the background task a error message should // be emitted. - Err(_) => Error::Custom("Error reason could not be found. This is a bug. Please open an issue.".to_string()), + Err(_) => Error::Custom( + "Error reason could not be found. This is a bug. Please open an issue.".to_string(), + ), }); *write = Some(ReadErrorOnce::Read(arc_err.clone())); arc_err @@ -281,7 +278,7 @@ impl ClientBuilder { } /// Enable WebSocket ping/pong on the client. - /// + /// /// This only works if the transport supports WebSocket pings. /// /// Default: pings are disabled. @@ -332,11 +329,16 @@ impl ClientBuilder { Some(p) => { // NOTE: This emits a tick immediately to sync how the `inactive_interval` works // because it starts measuring when the client start-ups. - let ping_interval = IntervalStream::new(tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(p.ping_interval))); + let ping_interval = IntervalStream::new(tokio_stream::wrappers::IntervalStream::new( + tokio::time::interval(p.ping_interval), + )); - let inactive_interval = { + let inactive_interval = { let start = tokio::time::Instant::now() + p.inactive_limit; - IntervalStream::new(tokio_stream::wrappers::IntervalStream::new(tokio::time::interval_at(start, p.inactive_limit))) + IntervalStream::new(tokio_stream::wrappers::IntervalStream::new(tokio::time::interval_at( + start, + p.inactive_limit, + ))) }; let inactivity_check = InactivityCheck::new(p.inactive_limit, p.max_failures); @@ -386,8 +388,8 @@ impl ClientBuilder { { use futures_util::stream::Pending; - type PendingIntervalStream = IntervalStream>; - + type PendingIntervalStream = IntervalStream>; + let (to_back, from_front) = mpsc::channel(self.max_concurrent_requests); let (err_to_front, err_from_back) = oneshot::channel::(); let max_buffer_capacity_per_subscription = self.max_buffer_capacity_per_subscription; @@ -466,12 +468,12 @@ impl Client { /// This is similar to [`Client::on_disconnect`] but it can be used to get /// the reason why the client was disconnected but it's not cancel-safe. - /// + /// /// The typical use-case is that this method will be called after /// [`Client::on_disconnect`] has returned in a "select loop". - /// + /// /// # Cancel-safety - /// + /// /// This method is not cancel-safe pub async fn disconnect_reason(&self) -> Error { self.error.read_error().await @@ -554,7 +556,7 @@ impl ClientT for Client { Err(_) => return Err(self.disconnect_reason().await), }; - rx_log_from_json(&Response::new(ResponsePayload::result_borrowed(&json_value), id), self.max_log_length); + rx_log_from_json(&Response::new(ResponsePayload::success_borrowed(&json_value), id), self.max_log_length); serde_json::from_value(json_value).map_err(Error::ParseError) } @@ -643,9 +645,7 @@ impl SubscriptionClientT for Client { Notif: DeserializeOwned, { if subscribe_method == unsubscribe_method { - return Err(RegisterMethodError::SubscriptionNameConflict( - unsubscribe_method.to_owned(), - ).into()); + return Err(RegisterMethodError::SubscriptionNameConflict(unsubscribe_method.to_owned()).into()); } let guard = self.id_manager.next_request_two_ids()?; @@ -680,7 +680,7 @@ impl SubscriptionClientT for Client { Err(_) => return Err(self.disconnect_reason().await), }; - rx_log_from_json(&Response::new(ResponsePayload::result_borrowed(&sub_id), id_unsub), self.max_log_length); + rx_log_from_json(&Response::new(ResponsePayload::success_borrowed(&sub_id), id_unsub), self.max_log_length); Ok(Subscription::new(self.to_back.clone(), notifs_rx, SubscriptionKind::Subscription(sub_id))) } @@ -901,8 +901,7 @@ async fn handle_frontend_messages( if manager.lock().insert_notification_handler(®.method, subscribe_tx).is_ok() { let _ = reg.send_back.send(Ok((subscribe_rx, reg.method))); } else { - let _ = - reg.send_back.send(Err(RegisterMethodError::AlreadyRegistered(reg.method).into())); + let _ = reg.send_back.send(Err(RegisterMethodError::AlreadyRegistered(reg.method).into())); } } // User dropped the NotificationHandler for this method @@ -950,30 +949,30 @@ where // This is safe because `tokio::time::Interval`, `tokio::mpsc::Sender` and `tokio::mpsc::Receiver` // are cancel-safe. - let res = loop { - tokio::select! { - biased; - _ = close_tx.closed() => break Ok(()), - maybe_msg = from_frontend.recv() => { - let Some(msg) = maybe_msg else { - break Ok(()); - }; - - if let Err(e) = - handle_frontend_messages(msg, &manager, &mut sender, max_buffer_capacity_per_subscription).await - { - tracing::error!(target: LOG_TARGET, "ws send failed: {e}"); - break Err(Error::Transport(e.into())); - } + let res = loop { + tokio::select! { + biased; + _ = close_tx.closed() => break Ok(()), + maybe_msg = from_frontend.recv() => { + let Some(msg) = maybe_msg else { + break Ok(()); + }; + + if let Err(e) = + handle_frontend_messages(msg, &manager, &mut sender, max_buffer_capacity_per_subscription).await + { + tracing::error!(target: LOG_TARGET, "ws send failed: {e}"); + break Err(Error::Transport(e.into())); } - _ = ping_interval.next() => { - if let Err(err) = sender.send_ping().await { - tracing::error!(target: LOG_TARGET, "Send ws ping failed: {err}"); - break Err(Error::Transport(err.into())); - } + } + _ = ping_interval.next() => { + if let Err(err) = sender.send_ping().await { + tracing::error!(target: LOG_TARGET, "Send ws ping failed: {err}"); + break Err(Error::Transport(err.into())); } } - }; + } + }; from_frontend.close(); let _ = sender.close().await; @@ -995,7 +994,15 @@ where R: TransportReceiverT, S: Stream + Unpin, { - let ReadTaskParams { receiver, close_tx, to_send_task, manager, max_buffer_capacity_per_subscription, mut inactivity_check, mut inactivity_stream } = params; + let ReadTaskParams { + receiver, + close_tx, + to_send_task, + manager, + max_buffer_capacity_per_subscription, + mut inactivity_check, + mut inactivity_stream, + } = params; let backend_event = futures_util::stream::unfold(receiver, |mut receiver| async { let res = receiver.receive().await; diff --git a/core/src/server/helpers.rs b/core/src/server/helpers.rs index 9b121db775..9f444aa3af 100644 --- a/core/src/server/helpers.rs +++ b/core/src/server/helpers.rs @@ -27,16 +27,9 @@ use std::io; use std::time::Duration; -use jsonrpsee_types::error::{ - reject_too_big_batch_response, ErrorCode, ErrorObject, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG, -}; -use jsonrpsee_types::{Id, InvalidRequest, Response, ResponsePayload}; -use serde::Serialize; -use serde_json::value::to_raw_value; +use jsonrpsee_types::{ErrorCode, ErrorObject, Id, InvalidRequest, Response, ResponsePayload}; use tokio::sync::mpsc; -use crate::server::LOG_TARGET; - use super::{DisconnectError, SendTimeoutError, SubscriptionMessage, TrySendError}; /// Bounded writer that allows writing at most `max_len` bytes. @@ -100,7 +93,7 @@ impl MethodSink { } /// Create a new `MethodSink` with a limited response size. - pub fn new_with_limit(tx: mpsc::Sender, max_response_size: u32 ) -> Self { + pub fn new_with_limit(tx: mpsc::Sender, max_response_size: u32) -> Self { MethodSink { tx, max_response_size } } @@ -138,8 +131,8 @@ impl MethodSink { /// Send a JSON-RPC error to the client pub async fn send_error<'a>(&self, id: Id<'a>, err: ErrorObject<'a>) -> Result<(), DisconnectError> { - let json = - serde_json::to_string(&Response::new(ResponsePayload::<()>::Error(err), id)).expect("valid JSON; qed"); + let payload = ResponsePayload::<()>::error_borrowed(err); + let json = serde_json::to_string(&Response::new(payload, id)).expect("valid JSON; qed"); self.send(json).await } @@ -169,221 +162,15 @@ pub fn prepare_error(data: &[u8]) -> (Id<'_>, ErrorCode) { } } -/// Represents a response to a method call. -/// -/// NOTE: A subscription is also a method call but it's -/// possible determine whether a method response -/// is "subscription" or "ordinary method call" -/// by calling [`MethodResponse::is_subscription`] -#[derive(Debug, Clone)] -pub struct MethodResponse { - /// Serialized JSON-RPC response, - pub result: String, - /// Indicates whether the call was successful or not. - pub success_or_error: MethodResponseResult, - /// Indicates whether the call was a subscription response. - pub is_subscription: bool, -} - -impl MethodResponse { - /// Returns whether the call was successful. - pub fn is_success(&self) -> bool { - self.success_or_error.is_success() - } - - /// Returns whether the call failed. - pub fn is_error(&self) -> bool { - self.success_or_error.is_success() - } - - /// Returns whether the call is a subscription. - pub fn is_subscription(&self) -> bool { - self.is_subscription - } -} - -/// Represent the outcome of a method call success or failed. -#[derive(Debug, Copy, Clone)] -pub enum MethodResponseResult { - /// The method call was successful. - Success, - /// The method call failed with error code. - Failed(i32), -} - -impl MethodResponseResult { - /// Returns whether the call was successful. - pub fn is_success(&self) -> bool { - matches!(self, MethodResponseResult::Success) - } - - /// Returns whether the call failed. - pub fn is_error(&self) -> bool { - matches!(self, MethodResponseResult::Failed(_)) - } - - /// Get the error code - /// - /// Returns `Some(error code)` if the call failed. - pub fn as_error_code(&self) -> Option { - match self { - Self::Failed(e) => Some(*e), - _ => None, - } - } -} - -impl MethodResponse { - /// This is similar to [`MethodResponse::response`] but sets a flag to indicate - /// that response is a subscription. - pub fn subscription_response(id: Id, result: ResponsePayload, max_response_size: usize) -> Self - where - T: Serialize + Clone, - { - let mut rp = Self::response(id, result, max_response_size); - rp.is_subscription = true; - rp - } - - /// Create a new method response. - /// - /// If the serialization of `result` exceeds `max_response_size` then - /// the response is changed to an JSON-RPC error object. - pub fn response(id: Id, result: ResponsePayload, max_response_size: usize) -> Self - where - T: Serialize + Clone, - { - let mut writer = BoundedWriter::new(max_response_size); - - let success_or_error = if let ResponsePayload::Error(ref e) = result { - MethodResponseResult::Failed(e.code()) - } else { - MethodResponseResult::Success - }; - - match serde_json::to_writer(&mut writer, &Response::new(result, id.clone())) { - Ok(_) => { - // Safety - serde_json does not emit invalid UTF-8. - let result = unsafe { String::from_utf8_unchecked(writer.into_bytes()) }; - - Self { result, success_or_error, is_subscription: false } - } - Err(err) => { - tracing::error!(target: LOG_TARGET, "Error serializing response: {:?}", err); - - if err.is_io() { - let data = to_raw_value(&format!("Exceeded max limit of {max_response_size}")).ok(); - let err_code = OVERSIZED_RESPONSE_CODE; - - let err = ResponsePayload::error_borrowed(ErrorObject::borrowed( - err_code, - OVERSIZED_RESPONSE_MSG, - data.as_deref(), - )); - let result = - serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed"); - - Self { result, success_or_error: MethodResponseResult::Failed(err_code), is_subscription: false } - } else { - let err_code = ErrorCode::InternalError; - let result = serde_json::to_string(&Response::new(err_code.into(), id)) - .expect("JSON serialization infallible; qed"); - Self { - result, - success_or_error: MethodResponseResult::Failed(err_code.code()), - is_subscription: false, - } - } - } - } - } - - /// This is similar to [`MethodResponse::error`] but sets a flag to indicate - /// that error is a subscription. - pub fn subscription_error<'a>(id: Id, err: impl Into>) -> Self { - let mut rp = Self::error(id, err); - rp.is_subscription = true; - rp - } - - /// Create a [`MethodResponse`] from a JSON-RPC error. - pub fn error<'a>(id: Id, err: impl Into>) -> Self { - let err: ErrorObject = err.into(); - let err_code = err.code(); - let err = ResponsePayload::error_borrowed(err); - let result = serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed"); - Self { result, success_or_error: MethodResponseResult::Failed(err_code), is_subscription: false } - } -} - -/// Builder to build a `BatchResponse`. -#[derive(Debug, Clone, Default)] -pub struct BatchResponseBuilder { - /// Serialized JSON-RPC response, - result: String, - /// Max limit for the batch - max_response_size: usize, -} - -impl BatchResponseBuilder { - /// Create a new batch response builder with limit. - pub fn new_with_limit(limit: usize) -> Self { - let mut initial = String::with_capacity(2048); - initial.push('['); - - Self { result: initial, max_response_size: limit } - } - - /// Append a result from an individual method to the batch response. - /// - /// Fails if the max limit is exceeded and returns to error response to - /// return early in order to not process method call responses which are thrown away anyway. - pub fn append(&mut self, response: &MethodResponse) -> Result<(), MethodResponse> { - // `,` will occupy one extra byte for each entry - // on the last item the `,` is replaced by `]`. - let len = response.result.len() + self.result.len() + 1; - - if len > self.max_response_size { - Err(MethodResponse::error(Id::Null, reject_too_big_batch_response(self.max_response_size))) - } else { - self.result.push_str(&response.result); - self.result.push(','); - Ok(()) - } - } - - /// Check if the batch is empty. - pub fn is_empty(&self) -> bool { - self.result.len() <= 1 - } - - /// Finish the batch response - pub fn finish(mut self) -> String { - if self.result.len() == 1 { - batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)) - } else { - self.result.pop(); - self.result.push(']'); - self.result - } - } -} - -/// Create a JSON-RPC error response. -pub fn batch_response_error(id: Id, err: impl Into>) -> String { - let err = ResponsePayload::error_borrowed(err); - serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed") -} - #[cfg(test)] mod tests { - use super::{BatchResponseBuilder, BoundedWriter, Id, MethodResponse, Response}; - use jsonrpsee_types::ResponsePayload; + use crate::server::BoundedWriter; + use jsonrpsee_types::{Id, Response, ResponsePayload}; #[test] fn bounded_serializer_work() { let mut writer = BoundedWriter::new(100); - let result = ResponsePayload::result(&"success"); + let result = ResponsePayload::success(&"success"); let rp = &Response::new(result, Id::Number(1)); assert!(serde_json::to_writer(&mut writer, rp).is_ok()); @@ -396,52 +183,4 @@ mod tests { // NOTE: `"` is part of the serialization so 101 characters. assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err()); } - - #[test] - fn batch_with_single_works() { - let method = MethodResponse::response(Id::Number(1), ResponsePayload::result_borrowed(&"a"), usize::MAX); - assert_eq!(method.result.len(), 37); - - // Recall a batch appends two bytes for the `[]`. - let mut builder = BatchResponseBuilder::new_with_limit(39); - builder.append(&method).unwrap(); - let batch = builder.finish(); - - assert_eq!(batch, r#"[{"jsonrpc":"2.0","result":"a","id":1}]"#) - } - - #[test] - fn batch_with_multiple_works() { - let m1 = MethodResponse::response(Id::Number(1), ResponsePayload::result_borrowed(&"a"), usize::MAX); - assert_eq!(m1.result.len(), 37); - - // Recall a batch appends two bytes for the `[]` and one byte for `,` to append a method call. - // so it should be 2 + (37 * n) + (n-1) - let limit = 2 + (37 * 2) + 1; - let mut builder = BatchResponseBuilder::new_with_limit(limit); - builder.append(&m1).unwrap(); - builder.append(&m1).unwrap(); - let batch = builder.finish(); - - assert_eq!(batch, r#"[{"jsonrpc":"2.0","result":"a","id":1},{"jsonrpc":"2.0","result":"a","id":1}]"#) - } - - #[test] - fn batch_empty_err() { - let batch = BatchResponseBuilder::new_with_limit(1024).finish(); - - let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}"#; - assert_eq!(batch, exp_err); - } - - #[test] - fn batch_too_big() { - let method = MethodResponse::response(Id::Number(1), ResponsePayload::result_borrowed(&"a".repeat(28)), 128); - assert_eq!(method.result.len(), 64); - - let batch = BatchResponseBuilder::new_with_limit(63).append(&method).unwrap_err(); - - let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32011,"message":"The batch response was too large","data":"Exceeded max limit of 63"},"id":null}"#; - assert_eq!(batch.result, exp_err); - } } diff --git a/core/src/server/method_response.rs b/core/src/server/method_response.rs new file mode 100644 index 0000000000..e4965a2937 --- /dev/null +++ b/core/src/server/method_response.rs @@ -0,0 +1,494 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use crate::server::{BoundedWriter, LOG_TARGET}; +use std::task::Poll; + +use futures_util::{Future, FutureExt}; +use jsonrpsee_types::error::{ + reject_too_big_batch_response, ErrorCode, ErrorObject, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG, +}; +use jsonrpsee_types::{ErrorObjectOwned, Id, Response, ResponsePayload as InnerResponsePayload}; +use serde::Serialize; +use serde_json::value::to_raw_value; + +#[derive(Debug, Clone)] +enum ResponseKind { + MethodCall, + Subscription, + Batch, +} + +/// Represents a response to a method call. +/// +/// NOTE: A subscription is also a method call but it's +/// possible determine whether a method response +/// is "subscription" or "ordinary method call" +/// by calling [`MethodResponse::is_subscription`] +#[derive(Debug)] +pub struct MethodResponse { + /// Serialized JSON-RPC response, + result: String, + /// Indicates whether the call was successful or not. + success_or_error: MethodResponseResult, + /// Indicates whether the call was a subscription response. + kind: ResponseKind, + /// Optional callback that may be utilized to notif + /// that the method response has been processed + on_close: Option, +} + +impl MethodResponse { + /// Returns whether the call was successful. + pub fn is_success(&self) -> bool { + self.success_or_error.is_success() + } + + /// Returns whether the call failed. + pub fn is_error(&self) -> bool { + self.success_or_error.is_error() + } + + /// Returns whether the response is a subscription response. + pub fn is_subscription(&self) -> bool { + matches!(self.kind, ResponseKind::Subscription) + } + + /// Returns whether the response is a method response. + pub fn is_method_call(&self) -> bool { + matches!(self.kind, ResponseKind::MethodCall) + } + + /// Returns whether the response is a batch response. + pub fn is_batch(&self) -> bool { + matches!(self.kind, ResponseKind::Batch) + } + + /// Consume the method response and extract the serialized response. + pub fn into_result(self) -> String { + self.result + } + + /// Extract the serialized response as a String. + pub fn to_result(&self) -> String { + self.result.clone() + } + + /// Consume the method response and extract the parts. + pub fn into_parts(self) -> (String, Option) { + (self.result, self.on_close) + } + + /// Get the error code + /// + /// Returns `Some(error code)` if the call failed. + pub fn as_error_code(&self) -> Option { + self.success_or_error.as_error_code() + } + + /// Get a reference to the serialized response. + pub fn as_result(&self) -> &str { + &self.result + } + + /// Create a method response from [`BatchResponse`]. + pub fn from_batch(batch: BatchResponse) -> Self { + Self { + result: batch.0, + success_or_error: MethodResponseResult::Success, + kind: ResponseKind::Batch, + on_close: None, + } + } + + /// This is similar to [`MethodResponse::response`] but sets a flag to indicate + /// that response is a subscription. + pub fn subscription_response(id: Id, result: ResponsePayload, max_response_size: usize) -> Self + where + T: Serialize + Clone, + { + let mut rp = Self::response(id, result, max_response_size); + rp.kind = ResponseKind::Subscription; + rp + } + + /// Create a new method response. + /// + /// If the serialization of `result` exceeds `max_response_size` then + /// the response is changed to an JSON-RPC error object. + pub fn response(id: Id, rp: ResponsePayload, max_response_size: usize) -> Self + where + T: Serialize + Clone, + { + let mut writer = BoundedWriter::new(max_response_size); + + let success_or_error = if let InnerResponsePayload::Error(ref e) = rp.inner { + MethodResponseResult::Failed(e.code()) + } else { + MethodResponseResult::Success + }; + + let kind = ResponseKind::MethodCall; + + match serde_json::to_writer(&mut writer, &Response::new(rp.inner, id.clone())) { + Ok(_) => { + // Safety - serde_json does not emit invalid UTF-8. + let result = unsafe { String::from_utf8_unchecked(writer.into_bytes()) }; + + Self { result, success_or_error, kind, on_close: rp.on_exit } + } + Err(err) => { + tracing::error!(target: LOG_TARGET, "Error serializing response: {:?}", err); + + if err.is_io() { + let data = to_raw_value(&format!("Exceeded max limit of {max_response_size}")).ok(); + let err_code = OVERSIZED_RESPONSE_CODE; + + let err = InnerResponsePayload::<()>::error_borrowed(ErrorObject::borrowed( + err_code, + OVERSIZED_RESPONSE_MSG, + data.as_deref(), + )); + let result = + serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed"); + + Self { + result, + success_or_error: MethodResponseResult::Failed(err_code), + kind, + on_close: rp.on_exit, + } + } else { + let err = ErrorCode::InternalError; + let payload = jsonrpsee_types::ResponsePayload::<()>::error(err); + let result = + serde_json::to_string(&Response::new(payload, id)).expect("JSON serialization infallible; qed"); + Self { + result, + success_or_error: MethodResponseResult::Failed(err.code()), + kind, + on_close: rp.on_exit, + } + } + } + } + } + + /// This is similar to [`MethodResponse::error`] but sets a flag to indicate + /// that error is a subscription. + pub fn subscription_error<'a>(id: Id, err: impl Into>) -> Self { + let mut rp = Self::error(id, err); + rp.kind = ResponseKind::Subscription; + rp + } + + /// Create a [`MethodResponse`] from a JSON-RPC error. + pub fn error<'a>(id: Id, err: impl Into>) -> Self { + let err: ErrorObject = err.into(); + let err_code = err.code(); + let err = InnerResponsePayload::<()>::error_borrowed(err); + let result = serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed"); + Self { + result, + success_or_error: MethodResponseResult::Failed(err_code), + kind: ResponseKind::MethodCall, + on_close: None, + } + } +} + +/// Represent the outcome of a method call success or failed. +#[derive(Debug, Copy, Clone)] +enum MethodResponseResult { + /// The method call was successful. + Success, + /// The method call failed with error code. + Failed(i32), +} + +impl MethodResponseResult { + /// Returns whether the call was successful. + fn is_success(&self) -> bool { + matches!(self, MethodResponseResult::Success) + } + + /// Returns whether the call failed. + fn is_error(&self) -> bool { + matches!(self, MethodResponseResult::Failed(_)) + } + + /// Get the error code + /// + /// Returns `Some(error code)` if the call failed. + fn as_error_code(&self) -> Option { + match self { + Self::Failed(e) => Some(*e), + _ => None, + } + } +} + +/// Builder to build a `BatchResponse`. +#[derive(Debug, Clone, Default)] +pub struct BatchResponseBuilder { + /// Serialized JSON-RPC response, + result: String, + /// Max limit for the batch + max_response_size: usize, +} + +impl BatchResponseBuilder { + /// Create a new batch response builder with limit. + pub fn new_with_limit(limit: usize) -> Self { + let mut initial = String::with_capacity(2048); + initial.push('['); + + Self { result: initial, max_response_size: limit } + } + + /// Append a result from an individual method to the batch response. + /// + /// Fails if the max limit is exceeded and returns to error response to + /// return early in order to not process method call responses which are thrown away anyway. + pub fn append(&mut self, response: &MethodResponse) -> Result<(), MethodResponse> { + // `,` will occupy one extra byte for each entry + // on the last item the `,` is replaced by `]`. + let len = response.result.len() + self.result.len() + 1; + + if len > self.max_response_size { + Err(MethodResponse::error(Id::Null, reject_too_big_batch_response(self.max_response_size))) + } else { + self.result.push_str(&response.result); + self.result.push(','); + Ok(()) + } + } + + /// Check if the batch is empty. + pub fn is_empty(&self) -> bool { + self.result.len() <= 1 + } + + /// Finish the batch response + pub fn finish(mut self) -> BatchResponse { + if self.result.len() == 1 { + BatchResponse(batch_response_error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))) + } else { + self.result.pop(); + self.result.push(']'); + BatchResponse(self.result) + } + } +} + +/// Serialized batch response. +#[derive(Debug, Clone)] +pub struct BatchResponse(String); + +/// Create a JSON-RPC error response. +pub fn batch_response_error(id: Id, err: impl Into>) -> String { + let err = InnerResponsePayload::<()>::error_borrowed(err); + serde_json::to_string(&Response::new(err, id)).expect("JSON serialization infallible; qed") +} + +/// Similar to [`jsonrpsee_types::ResponsePayload`] but possible to with an async-like +/// API to detect when a method response has been sent. +#[derive(Debug)] +pub struct ResponsePayload<'a, T> +where + T: Clone, +{ + inner: InnerResponsePayload<'a, T>, + on_exit: Option, +} + +impl<'a, T: Clone> From> for ResponsePayload<'a, T> { + fn from(inner: InnerResponsePayload<'a, T>) -> Self { + Self { inner, on_exit: None } + } +} + +impl<'a, T> ResponsePayload<'a, T> +where + T: Clone, +{ + /// Create a successful owned response payload. + pub fn success(t: T) -> Self { + InnerResponsePayload::success(t).into() + } + + /// Create a successful borrowed response payload. + pub fn success_borrowed(t: &'a T) -> Self { + InnerResponsePayload::success_borrowed(t).into() + } + + /// Create an error response payload. + pub fn error(e: impl Into) -> Self { + InnerResponsePayload::error(e.into()).into() + } + + /// Create a borrowd error response payload. + pub fn error_borrowed(e: impl Into>) -> Self { + InnerResponsePayload::error_borrowed(e.into()).into() + } + + /// Consumes the [`ResponsePayload`] and produces new [`ResponsePayload`] and a future + /// [`MethodResponseFuture`] that will be resolved once the response has been processed. + /// + /// If this has been called more than once then this will overwrite + /// the old result the previous future(s) will be resolved with error. + pub fn notify_on_completion(mut self) -> (Self, MethodResponseFuture) { + let (tx, rx) = response_channel(); + self.on_exit = Some(tx); + (self, rx) + } + + /// Convert the response payload into owned. + pub fn into_owned(self) -> ResponsePayload<'static, T> { + ResponsePayload { inner: self.inner.into_owned(), on_exit: self.on_exit } + } +} + +impl<'a, T> From for ResponsePayload<'a, T> +where + T: Clone, +{ + fn from(code: ErrorCode) -> Self { + let err: ErrorObject = code.into(); + Self::error(err) + } +} + +/// Create a channel to be used in combination with [`ResponsePayload`] to +/// notify when a method call has been processed. +fn response_channel() -> (MethodResponseNotifyTx, MethodResponseFuture) { + let (tx, rx) = tokio::sync::oneshot::channel(); + (MethodResponseNotifyTx(tx), MethodResponseFuture(rx)) +} + +/// Sends a message once the method response has been processed. +#[derive(Debug)] +pub struct MethodResponseNotifyTx(tokio::sync::oneshot::Sender); + +impl MethodResponseNotifyTx { + /// Send a notify message. + pub fn notify(self, is_success: bool) { + let msg = if is_success { NotifyMsg::Ok } else { NotifyMsg::Err }; + _ = self.0.send(msg); + } +} + +/// Future that resolves when the method response has been processed. +#[derive(Debug)] +pub struct MethodResponseFuture(tokio::sync::oneshot::Receiver); + +/// A message that that tells whether notification +/// was succesful or not. +#[derive(Debug, Copy, Clone)] +pub enum NotifyMsg { + /// The response was succesfully processed. + Ok, + /// The response was the wrong kind + /// such an error response when + /// one expected a succesful response. + Err, +} + +/// Method response error. +#[derive(Debug, Copy, Clone)] +pub enum MethodResponseError { + /// The connection was closed. + Closed, + /// The response was a JSON-RPC error. + JsonRpcError, +} + +impl Future for MethodResponseFuture { + type Output = Result<(), MethodResponseError>; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + match self.0.poll_unpin(cx) { + Poll::Ready(Ok(NotifyMsg::Ok)) => Poll::Ready(Ok(())), + Poll::Ready(Ok(NotifyMsg::Err)) => Poll::Ready(Err(MethodResponseError::JsonRpcError)), + Poll::Ready(Err(_)) => Poll::Ready(Err(MethodResponseError::Closed)), + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod tests { + use super::{BatchResponseBuilder, MethodResponse, ResponsePayload}; + use jsonrpsee_types::Id; + + #[test] + fn batch_with_single_works() { + let method = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX); + assert_eq!(method.result.len(), 37); + + // Recall a batch appends two bytes for the `[]`. + let mut builder = BatchResponseBuilder::new_with_limit(39); + builder.append(&method).unwrap(); + let batch = builder.finish(); + + assert_eq!(batch.0, r#"[{"jsonrpc":"2.0","result":"a","id":1}]"#) + } + + #[test] + fn batch_with_multiple_works() { + let m1 = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a"), usize::MAX); + assert_eq!(m1.result.len(), 37); + + // Recall a batch appends two bytes for the `[]` and one byte for `,` to append a method call. + // so it should be 2 + (37 * n) + (n-1) + let limit = 2 + (37 * 2) + 1; + let mut builder = BatchResponseBuilder::new_with_limit(limit); + builder.append(&m1).unwrap(); + builder.append(&m1).unwrap(); + let batch = builder.finish(); + + assert_eq!(batch.0, r#"[{"jsonrpc":"2.0","result":"a","id":1},{"jsonrpc":"2.0","result":"a","id":1}]"#) + } + + #[test] + fn batch_empty_err() { + let batch = BatchResponseBuilder::new_with_limit(1024).finish(); + + let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}"#; + assert_eq!(batch.0, exp_err); + } + + #[test] + fn batch_too_big() { + let method = MethodResponse::response(Id::Number(1), ResponsePayload::success_borrowed(&"a".repeat(28)), 128); + assert_eq!(method.result.len(), 64); + + let batch = BatchResponseBuilder::new_with_limit(63).append(&method).unwrap_err(); + + let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32011,"message":"The batch response was too large","data":"Exceeded max limit of 63"},"id":null}"#; + assert_eq!(batch.result, exp_err); + } +} diff --git a/core/src/server/mod.rs b/core/src/server/mod.rs index 38a399c08a..46133c5afa 100644 --- a/core/src/server/mod.rs +++ b/core/src/server/mod.rs @@ -30,17 +30,20 @@ mod error; /// Helpers. pub mod helpers; +/// Method response related types. +mod method_response; /// JSON-RPC "modules" group sets of methods that belong together and handles method/subscription registration. mod rpc_module; /// Subscription related types. mod subscription; pub use error::*; -pub use helpers::{BatchResponseBuilder, BoundedWriter, MethodResponse, MethodSink}; +pub use helpers::*; +pub use method_response::*; pub use rpc_module::*; pub use subscription::*; -use jsonrpsee_types::{ErrorObjectOwned, ResponsePayload}; +use jsonrpsee_types::ErrorObjectOwned; const LOG_TARGET: &str = "jsonrpsee-server"; @@ -64,8 +67,8 @@ where fn into_response(self) -> ResponsePayload<'static, Self::Output> { match self { - Ok(val) => ResponsePayload::result(val), - Err(e) => ResponsePayload::Error(e.into()), + Ok(val) => ResponsePayload::success(val), + Err(e) => ResponsePayload::error(e), } } } @@ -77,7 +80,7 @@ where type Output = Option; fn into_response(self) -> ResponsePayload<'static, Self::Output> { - ResponsePayload::result(self) + ResponsePayload::success(self) } } @@ -88,7 +91,7 @@ where type Output = Vec; fn into_response(self) -> ResponsePayload<'static, Self::Output> { - ResponsePayload::result(self) + ResponsePayload::success(self) } } @@ -99,7 +102,18 @@ where type Output = [T; N]; fn into_response(self) -> ResponsePayload<'static, Self::Output> { - ResponsePayload::result(self) + ResponsePayload::success(self) + } +} + +impl IntoResponse for jsonrpsee_types::ResponsePayload<'static, T> +where + T: serde::Serialize + Clone, +{ + type Output = T; + + fn into_response(self) -> ResponsePayload<'static, Self::Output> { + self.into() } } @@ -115,10 +129,10 @@ where } impl IntoResponse for ErrorObjectOwned { - type Output = ErrorObjectOwned; + type Output = (); fn into_response(self) -> ResponsePayload<'static, Self::Output> { - ResponsePayload::Error(self) + ResponsePayload::error(self) } } @@ -129,7 +143,7 @@ macro_rules! impl_into_response { type Output = $n; fn into_response(self) -> ResponsePayload<'static, Self::Output> { - ResponsePayload::result(self) + ResponsePayload::success(self) } } )+ diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index abceec27d8..676534f2ed 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -32,18 +32,19 @@ use std::sync::Arc; use crate::error::RegisterMethodError; use crate::id_providers::RandomIntegerIdProvider; -use crate::server::LOG_TARGET; -use crate::server::helpers::{MethodResponse, MethodSink}; +use crate::server::helpers::MethodSink; +use crate::server::method_response::MethodResponse; use crate::server::subscription::{ sub_message_to_json, BoundedSubscriptions, IntoSubscriptionCloseResponse, PendingSubscriptionSink, SubNotifResultOrError, Subscribers, Subscription, SubscriptionCloseResponse, SubscriptionKey, SubscriptionPermit, SubscriptionState, }; +use crate::server::{ResponsePayload, LOG_TARGET}; use crate::traits::ToRpcParams; use futures_util::{future::BoxFuture, FutureExt}; use jsonrpsee_types::error::{ErrorCode, ErrorObject}; use jsonrpsee_types::{ - Id, Params, Request, Response, ResponsePayload, ResponseSuccess, SubscriptionId as RpcSubscriptionId, ErrorObjectOwned, + ErrorObjectOwned, Id, Params, Request, Response, ResponseSuccess, SubscriptionId as RpcSubscriptionId, }; use rustc_hash::FxHashMap; use serde::de::DeserializeOwned; @@ -78,7 +79,6 @@ pub type MaxResponseSize = usize; /// - a [`mpsc::UnboundedReceiver`] to receive future subscription results pub type RawRpcResponse = (MethodResponse, mpsc::Receiver); - /// The error that can occur when [`Methods::call`] or [`Methods::subscribe`] is invoked. #[derive(thiserror::Error, Debug)] pub enum MethodsError { @@ -97,12 +97,11 @@ pub enum MethodsError { /// and `Subscribe` calls are handled differently /// because we want to prevent subscriptions to start /// before the actual subscription call has been answered. -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum CallOrSubscription { /// The subscription callback itself sends back the result /// so it must not be sent back again. Subscription(MethodResponse), - /// Treat it as ordinary call. Call(MethodResponse), } @@ -292,7 +291,7 @@ impl Methods { let req = Request::new(method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0)); tracing::trace!(target: LOG_TARGET, "[Methods::call] Method: {:?}, params: {:?}", method, params); let (resp, _) = self.inner_call(req, 1, mock_subscription_permit()).await; - let rp = serde_json::from_str::>(&resp.result)?; + let rp = serde_json::from_str::>(resp.as_result())?; ResponseSuccess::try_from(rp).map(|s| s.result).map_err(|e| MethodsError::JsonRpc(e.into_owned())) } @@ -319,7 +318,7 @@ impl Methods { /// Ok(()) /// }).unwrap(); /// let (resp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"hi","id":0}"#, 1).await.unwrap(); - /// let resp: Success = serde_json::from_str::>(&resp.result).unwrap().try_into().unwrap(); + /// let resp: Success = serde_json::from_str::>(&resp.as_result()).unwrap().try_into().unwrap(); /// let sub_resp = stream.recv().await.unwrap(); /// assert_eq!( /// format!(r#"{{"jsonrpc":"2.0","method":"hi","params":{{"subscription":{},"result":"one answer"}}}}"#, resp.result), @@ -428,9 +427,9 @@ impl Methods { // TODO: hack around the lifetime on the `SubscriptionId` by deserialize first to serde_json::Value. let as_success: ResponseSuccess = - serde_json::from_str::>(&resp.result)?.try_into()?; + serde_json::from_str::>(resp.as_result())?.try_into()?; - let sub_id = as_success.result.try_into().map_err(|_| MethodsError::InvalidSubscriptionId(resp.result.clone()))?; + let sub_id = as_success.result.try_into().map_err(|_| MethodsError::InvalidSubscriptionId(resp.to_result()))?; Ok(Subscription { sub_id, rx }) } @@ -486,6 +485,15 @@ impl From> for Methods { impl RpcModule { /// Register a new synchronous RPC method, which computes the response with the given callback. + /// + /// ## Examples + /// + /// ``` + /// use jsonrpsee_core::server::RpcModule; + /// + /// let mut module = RpcModule::new(()); + /// module.register_method("say_hello", |_params, _ctx| "lo").unwrap(); + /// ``` pub fn register_method( &mut self, method_name: &'static str, @@ -507,6 +515,17 @@ impl RpcModule { } /// Register a new asynchronous RPC method, which computes the response with the given callback. + /// + /// ## Examples + /// + /// ``` + /// use jsonrpsee_core::server::RpcModule; + /// + /// let mut module = RpcModule::new(()); + /// module.register_async_method("say_hello", |_params, _ctx| async { "lo" }).unwrap(); + /// + /// ``` + /// pub fn register_async_method( &mut self, method_name: &'static str, @@ -884,7 +903,7 @@ impl RpcModule { id ); - return MethodResponse::response(id, ResponsePayload::result(false), max_response_size); + return MethodResponse::response(id, ResponsePayload::success(false), max_response_size); } }; @@ -900,7 +919,7 @@ impl RpcModule { ); } - MethodResponse::response(id, ResponsePayload::result(result), max_response_size) + MethodResponse::response(id, ResponsePayload::success(result), max_response_size) })), ); } diff --git a/core/src/server/subscription.rs b/core/src/server/subscription.rs index d4709ebf74..a71fe9fb79 100644 --- a/core/src/server/subscription.rs +++ b/core/src/server/subscription.rs @@ -26,16 +26,14 @@ //! Subscription related types and traits for server implementations. -use super::MethodsError; -use super::helpers::{MethodResponse, MethodSink}; -use crate::server::LOG_TARGET; +use super::helpers::MethodSink; +use super::{MethodResponse, MethodsError, ResponsePayload}; use crate::server::error::{DisconnectError, PendingSubscriptionAcceptError, SendTimeoutError, TrySendError}; use crate::server::rpc_module::ConnectionId; -use crate::{traits::IdProvider, error::StringError}; +use crate::server::LOG_TARGET; +use crate::{error::StringError, traits::IdProvider}; use jsonrpsee_types::SubscriptionPayload; -use jsonrpsee_types::{ - response::SubscriptionError, ErrorObjectOwned, Id, ResponsePayload, SubscriptionId, SubscriptionResponse, -}; +use jsonrpsee_types::{response::SubscriptionError, ErrorObjectOwned, Id, SubscriptionId, SubscriptionResponse}; use parking_lot::Mutex; use rustc_hash::FxHashMap; use serde::{de::DeserializeOwned, Serialize}; @@ -261,7 +259,7 @@ impl PendingSubscriptionSink { /// once reject has been called. pub async fn reject(self, err: impl Into) { let err = MethodResponse::subscription_error(self.id, err.into()); - _ = self.inner.send(err.result.clone()).await; + _ = self.inner.send(err.to_result()).await; _ = self.subscribe.send(err); } @@ -273,7 +271,7 @@ impl PendingSubscriptionSink { pub async fn accept(self) -> Result { let response = MethodResponse::subscription_response( self.id, - ResponsePayload::result_borrowed(&self.uniq_sub.sub_id), + ResponsePayload::success_borrowed(&self.uniq_sub.sub_id), self.inner.max_response_size() as usize, ); let success = response.is_success(); @@ -284,7 +282,7 @@ impl PendingSubscriptionSink { // // The same message is sent twice here because one is sent directly to the transport layer and // the other one is sent internally to accept the subscription. - self.inner.send(response.result.clone()).await.map_err(|_| PendingSubscriptionAcceptError)?; + self.inner.send(response.to_result()).await.map_err(|_| PendingSubscriptionAcceptError)?; self.subscribe.send(response).map_err(|_| PendingSubscriptionAcceptError)?; if success { diff --git a/examples/examples/response_payload_notify_on_response.rs b/examples/examples/response_payload_notify_on_response.rs new file mode 100644 index 0000000000..6662a5ed99 --- /dev/null +++ b/examples/examples/response_payload_notify_on_response.rs @@ -0,0 +1,84 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::net::SocketAddr; + +use jsonrpsee::core::client::ClientT; +use jsonrpsee::proc_macros::rpc; +use jsonrpsee::server::Server; +use jsonrpsee::ws_client::WsClientBuilder; +use jsonrpsee::{rpc_params, ResponsePayload}; + +#[rpc(client, server, namespace = "state")] +pub trait Rpc { + /// Async method call example. + #[method(name = "getKeys")] + fn storage_keys(&self) -> ResponsePayload<'static, String>; +} + +pub struct RpcServerImpl; + +impl RpcServer for RpcServerImpl { + fn storage_keys(&self) -> ResponsePayload<'static, String> { + let (rp, rp_future) = ResponsePayload::success("ehheeheh".to_string()).notify_on_completion(); + + tokio::spawn(async move { + rp_future.await.unwrap(); + println!("Method response to `state_getKeys` finished"); + }); + + rp + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + + let server_addr = run_server().await?; + let url = format!("ws://{}", server_addr); + + let client = WsClientBuilder::default().build(&url).await?; + assert_eq!("ehheeheh".to_string(), client.request::("state_getKeys", rpc_params![]).await.unwrap()); + + Ok(()) +} + +async fn run_server() -> anyhow::Result { + let server = Server::builder().build("127.0.0.1:0").await?; + + let addr = server.local_addr()?; + let handle = server.start(RpcServerImpl.into_rpc()); + + // In this example we don't care about doing shutdown so let's it run forever. + // You may use the `ServerHandle` to shut it down or manage it yourself. + tokio::spawn(handle.stopped()); + + Ok(addr) +} diff --git a/proc-macros/src/render_client.rs b/proc-macros/src/render_client.rs index 4f2f6c39d5..f331aad76b 100644 --- a/proc-macros/src/render_client.rs +++ b/proc-macros/src/render_client.rs @@ -109,10 +109,21 @@ impl RpcDescription { let ret_ty = args.last_mut().unwrap(); let err_ty = self.jrps_client_item(quote! { core::client::Error }); + quote! { core::result::Result<#ret_ty, #err_ty> } + } else if type_name.ident == "ResponsePayload" { + // ResponsePayload<'a, T> + if args.len() != 2 { + return quote_spanned!(args.span() => compile_error!("ResponsePayload must have exactly two arguments")); + } + + // The type alias `RpcResult` is modified to `Result`. + let ret_ty = args.last_mut().unwrap(); + let err_ty = self.jrps_client_item(quote! { core::client::Error }); + quote! { core::result::Result<#ret_ty, #err_ty> } } else { // Any other type name isn't allowed. - quote_spanned!(type_name.span() => compile_error!("The return type must be Result or RpcResult")) + quote_spanned!(type_name.span() => compile_error!("The return type must be Result, RpcResult or ResponsePayload<'static, T>")) } } diff --git a/proc-macros/src/render_server.rs b/proc-macros/src/render_server.rs index bb3f60900f..eddb9e3d99 100644 --- a/proc-macros/src/render_server.rs +++ b/proc-macros/src/render_server.rs @@ -299,7 +299,7 @@ impl RpcDescription { let params_fields = quote! { #(#params_fields_seq),* }; let tracing = self.jrps_server_item(quote! { tracing }); let sub_err = self.jrps_server_item(quote! { SubscriptionCloseResponse }); - let response_payload = self.jrps_server_item(quote! { types::ResponsePayload }); + let response_payload = self.jrps_server_item(quote! { ResponsePayload }); let tokio = self.jrps_server_item(quote! { tokio }); // Code to decode sequence of parameters from a JSON array. @@ -323,7 +323,7 @@ impl RpcDescription { Ok(v) => v, Err(e) => { #tracing::debug!(concat!("Error parsing optional \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); - return #response_payload::Error(e); + return #response_payload::error(e); } }; } @@ -346,7 +346,7 @@ impl RpcDescription { Ok(v) => v, Err(e) => { #tracing::debug!(concat!("Error parsing \"", stringify!(#name), "\" as \"", stringify!(#ty), "\": {:?}"), e); - return #response_payload::Error(e); + return #response_payload::error(e); } }; } @@ -431,7 +431,7 @@ impl RpcDescription { Ok(p) => p, Err(e) => { #tracing::debug!("Failed to parse JSON-RPC params as object: {}", e); - return #response_payload::Error(e); + return #response_payload::error(e); } }; diff --git a/proc-macros/tests/ui/correct/custom_ret_types.rs b/proc-macros/tests/ui/correct/custom_ret_types.rs index 8f89b55392..5a70c2f2b8 100644 --- a/proc-macros/tests/ui/correct/custom_ret_types.rs +++ b/proc-macros/tests/ui/correct/custom_ret_types.rs @@ -4,8 +4,7 @@ use std::net::SocketAddr; use jsonrpsee::core::{async_trait, ClientError, Serialize}; use jsonrpsee::proc_macros::rpc; -use jsonrpsee::server::{IntoResponse, ServerBuilder}; -use jsonrpsee::types::ResponsePayload; +use jsonrpsee::server::{IntoResponse, ResponsePayload, ServerBuilder}; use jsonrpsee::ws_client::*; // Serialize impl is not used as the responses are sent out as error. @@ -31,7 +30,7 @@ impl IntoResponse for CustomError { let data = data.map(|val| serde_json::value::to_raw_value(&val).unwrap()); let error_object = jsonrpsee::types::ErrorObjectOwned::owned(code, "custom_error", data); - ResponsePayload::Error(error_object) + ResponsePayload::error(error_object) } } diff --git a/proc-macros/tests/ui/incorrect/rpc/rpc_empty_bounds.stderr b/proc-macros/tests/ui/incorrect/rpc/rpc_empty_bounds.stderr index 050e6ec443..853ee29c4e 100644 --- a/proc-macros/tests/ui/incorrect/rpc/rpc_empty_bounds.stderr +++ b/proc-macros/tests/ui/incorrect/rpc/rpc_empty_bounds.stderr @@ -4,7 +4,7 @@ error[E0277]: the trait bound `::Hash: Serialize` is not satisfi 9 | #[rpc(server, client, namespace = "foo", client_bounds(), server_bounds())] | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Serialize` is not implemented for `::Hash` | - = note: required for `std::result::Result<::Hash, ErrorObject<'_>>` to implement `IntoResponse` + = note: required for `Result<::Hash, ErrorObject<'_>>` to implement `IntoResponse` = note: this error originates in the attribute macro `rpc` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `::Hash: Clone` is not satisfied @@ -13,7 +13,7 @@ error[E0277]: the trait bound `::Hash: Clone` is not satisfied 9 | #[rpc(server, client, namespace = "foo", client_bounds(), server_bounds())] | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Clone` is not implemented for `::Hash` | - = note: required for `std::result::Result<::Hash, ErrorObject<'_>>` to implement `IntoResponse` + = note: required for `Result<::Hash, ErrorObject<'_>>` to implement `IntoResponse` = note: this error originates in the attribute macro `rpc` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `for<'de> ::Hash: Deserialize<'de>` is not satisfied diff --git a/server/src/middleware/rpc/layer/logger.rs b/server/src/middleware/rpc/layer/logger.rs index b352134a3e..9849bce057 100644 --- a/server/src/middleware/rpc/layer/logger.rs +++ b/server/src/middleware/rpc/layer/logger.rs @@ -105,7 +105,7 @@ impl> Future for ResponseFuture { let res = fut.poll(cx); if let Poll::Ready(rp) = &res { - tx_log_from_str(&rp.result, max); + tx_log_from_str(rp.as_result(), max); } res } diff --git a/server/src/server.rs b/server/src/server.rs index 2adb48b23a..0f1b4d76d6 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -45,7 +45,7 @@ use futures_util::io::{BufReader, BufWriter}; use hyper::body::HttpBody; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; -use jsonrpsee_core::server::helpers::{prepare_error, MethodResponseResult}; +use jsonrpsee_core::server::helpers::prepare_error; use jsonrpsee_core::server::{BatchResponseBuilder, BoundedSubscriptions, MethodResponse, MethodSink, Methods}; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{JsonRawValue, TEN_MB_SIZE_BYTES}; @@ -617,18 +617,18 @@ impl Builder { /// impl<'a, S> RpcServiceT<'a> for MyMiddleware /// where S: RpcServiceT<'a> + Send + Sync + Clone + 'static, /// { - /// type Future = BoxFuture<'a, MethodResponse>; - /// + /// type Future = BoxFuture<'a, MethodResponse>; + /// /// fn call(&self, req: Request<'a>) -> Self::Future { /// tracing::info!("MyMiddleware processed call {}", req.method); /// let count = self.count.clone(); - /// let service = self.service.clone(); + /// let service = self.service.clone(); /// /// Box::pin(async move { /// let rp = service.call(req).await; /// // Modify the state. /// count.fetch_add(1, Ordering::Relaxed); - /// rp + /// rp /// }) /// } /// } @@ -1309,8 +1309,8 @@ where if got_notif && batch_response.is_empty() { None } else { - let result = batch_response.finish(); - Some(MethodResponse { result, success_or_error: MethodResponseResult::Success, is_subscription: false }) + let batch_rp = batch_response.finish(); + Some(MethodResponse::from_batch(batch_rp)) } } else { Some(MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError))) diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index a475d55c70..57599db4cb 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -92,7 +92,7 @@ where // If the response is empty it means that it was a notification or empty batch. // For HTTP these are just ACK:ed with a empty body. - response::ok_response(rp.map_or(String::new(), |r| r.result)) + response::ok_response(rp.map_or(String::new(), |r| r.into_result())) } // Error scenarios: Method::POST => response::unsupported_content_type(), @@ -110,7 +110,7 @@ pub mod response { /// Create a response for json internal error. pub fn internal_error() -> hyper::Response { - let err = ResponsePayload::error(ErrorObjectOwned::from(ErrorCode::InternalError)); + let err = ResponsePayload::<()>::error(ErrorObjectOwned::from(ErrorCode::InternalError)); let rp = Response::new(err, Id::Null); let error = serde_json::to_string(&rp).expect("built from known-good data; qed"); @@ -133,7 +133,7 @@ pub mod response { /// Create a json response for oversized requests (413) pub fn too_large(limit: u32) -> hyper::Response { - let err = ResponsePayload::error(reject_too_big_request(limit)); + let err = ResponsePayload::<()>::error(reject_too_big_request(limit)); let rp = Response::new(err, Id::Null); let error = serde_json::to_string(&rp).expect("JSON serialization infallible; qed"); @@ -142,7 +142,7 @@ pub mod response { /// Create a json response for empty or malformed requests (400) pub fn malformed() -> hyper::Response { - let rp = Response::new(ErrorCode::ParseError.into(), Id::Null); + let rp = Response::new(ResponsePayload::<()>::error(ErrorCode::ParseError), Id::Null); let error = serde_json::to_string(&rp).expect("JSON serialization infallible; qed"); from_template(hyper::StatusCode::BAD_REQUEST, error, JSON) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 2480529847..11f3bef36c 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -10,8 +10,7 @@ use futures_util::future::{self, Either}; use futures_util::io::{BufReader, BufWriter}; use futures_util::{Future, StreamExt, TryStreamExt}; use hyper::upgrade::Upgraded; -use jsonrpsee_core::server::helpers::MethodSink; -use jsonrpsee_core::server::{BoundedSubscriptions, Methods}; +use jsonrpsee_core::server::{BoundedSubscriptions, MethodSink, Methods}; use jsonrpsee_types::error::{reject_too_big_request, ErrorCode}; use jsonrpsee_types::Id; use soketto::connection::Error as SokettoError; @@ -155,8 +154,20 @@ where handle_rpc_call(&data[idx..], is_single, batch_requests_config, max_response_body_size, &*rpc_service) .await { - if !rp.is_subscription { - _ = sink.send(rp.result).await; + if !rp.is_subscription() { + let is_success = rp.is_success(); + let (serialized_rp, mut on_close) = rp.into_parts(); + + // The connection is closed, just quit. + if sink.send(serialized_rp).await.is_err() { + return; + } + + // Notify that the message has been sent out to the internal + // WebSocket buffer. + if let Some(n) = on_close.take() { + n.notify(is_success); + } } } }); diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 75d14fac0a..8a87cd1e7d 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -25,4 +25,4 @@ tower = { version = "0.4.13", features = ["full"] } tower-http = { version = "0.4.0", features = ["full"] } tracing = "0.1.34" tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } -pin-project = { version = "1" } +pin-project = "1" diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 71765b9c76..94548d53c2 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -37,10 +37,11 @@ use jsonrpsee::server::middleware::http::ProxyGetRequestLayer; use jsonrpsee::server::{ PendingSubscriptionSink, RpcModule, Server, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError, }; -use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; -use jsonrpsee::SubscriptionCloseResponse; -use serde::Serialize; +use jsonrpsee::types::{ErrorCode, ErrorObject, ErrorObjectOwned}; +use jsonrpsee::{MethodResponseError, ResponsePayload, SubscriptionCloseResponse}; +use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; +use tokio::sync::mpsc::UnboundedSender; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; use tower_http::cors::CorsLayer; @@ -287,3 +288,56 @@ pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream< .await .unwrap() } + +#[derive(Copy, Clone, Deserialize, Serialize)] +pub enum Notify { + Success, + Error, + All, +} + +pub type NotifyRpcModule = RpcModule>>; +pub type Sender = tokio::sync::mpsc::UnboundedSender>; +pub type Receiver = tokio::sync::mpsc::UnboundedReceiver>; + +pub async fn run_test_notify_test( + module: &NotifyRpcModule, + server_rx: &mut Receiver, + is_success: bool, + kind: Notify, +) -> Result<(), MethodResponseError> { + use jsonrpsee_test_utils::mocks::Id; + + let req = jsonrpsee_test_utils::helpers::call("hey", vec![kind], Id::Num(1)); + let (rp, _) = module.raw_json_request(&req, 1).await.unwrap(); + let (_, notify_rx) = rp.into_parts(); + notify_rx.unwrap().notify(is_success); + server_rx.recv().await.expect("Channel is not dropped") +} + +/// Helper module that will send the results on the channel passed in. +pub fn rpc_module_notify_on_response(tx: Sender) -> NotifyRpcModule { + let mut module = RpcModule::new(tx); + + module + .register_method("hey", |params, ctx| { + let kind: Notify = params.one().unwrap(); + let server_sender = ctx.clone(); + + let (rp, rp_future) = match kind { + Notify::All => ResponsePayload::success("lo").notify_on_completion(), + Notify::Success => ResponsePayload::success("lo").notify_on_completion(), + Notify::Error => ResponsePayload::error(ErrorCode::InvalidParams).notify_on_completion(), + }; + + tokio::spawn(async move { + let rp = rp_future.await; + server_sender.send(rp).unwrap(); + }); + + rp + }) + .unwrap(); + + module +} diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index c58c0538ab..b6991b6196 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -43,13 +43,13 @@ use hyper::http::HeaderValue; use jsonrpsee::core::client::{ClientT, Error, IdKind, Subscription, SubscriptionClientT}; use jsonrpsee::core::params::{ArrayParams, BatchRequestBuilder}; use jsonrpsee::core::server::SubscriptionMessage; -use jsonrpsee::core::JsonValue; +use jsonrpsee::core::{JsonValue, StringError}; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::server::middleware::http::HostFilterLayer; use jsonrpsee::server::{ServerBuilder, ServerHandle}; use jsonrpsee::types::error::{ErrorObject, UNKNOWN_ERROR_CODE}; use jsonrpsee::ws_client::WsClientBuilder; -use jsonrpsee::{rpc_params, RpcModule}; +use jsonrpsee::{rpc_params, ResponsePayload, RpcModule}; use jsonrpsee_test_utils::TimeoutFutureExt; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; @@ -1302,6 +1302,89 @@ async fn run_shutdown_test_inner( assert!(client.request::("sleep_20s", rpc_params!()).await.is_err()); } +#[tokio::test] +async fn response_payload_async_api_works() { + use jsonrpsee::server::{Server, SubscriptionSink}; + use std::sync::Arc; + use tokio::sync::Mutex as AsyncMutex; + + init_logger(); + + let server_addr = { + #[allow(clippy::type_complexity)] + let state: Arc)>>> = Arc::default(); + + let mut module = RpcModule::new(state); + module + .register_method("get", |_params, ctx| { + let ctx = ctx.clone(); + let (rp, rp_future) = ResponsePayload::success(1).notify_on_completion(); + + tokio::spawn(async move { + // Wait for response to sent to the internal WebSocket message buffer + // and if that fails just quit because it means that the connection + // was closed or that method response was an error. + // + // You can identify that by matching on the error. + if rp_future.await.is_err() { + return; + } + + if let Some((sink, close)) = ctx.lock().await.take() { + for idx in 0..3 { + let msg = SubscriptionMessage::from_json(&idx).unwrap(); + _ = sink.send(msg).await; + } + drop(sink); + drop(close); + } + }); + + rp + }) + .unwrap(); + + module + .register_subscription::, _, _>("sub", "s", "unsub", |_, pending, ctx| async move { + let sink = pending.accept().await?; + let (tx, rx) = tokio::sync::oneshot::channel(); + *ctx.lock().await = Some((sink, tx)); + let _ = rx.await; + Err("Dropped".into()) + }) + .unwrap(); + + let server = Server::builder().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); + let addr = server.local_addr().unwrap(); + + let handle = server.start(module); + + tokio::spawn(handle.stopped()); + + format!("ws://{addr}") + }; + + let client = jsonrpsee::ws_client::WsClientBuilder::default() + .build(&server_addr) + .with_default_timeout() + .await + .unwrap() + .unwrap(); + + // Make a subscription which is stored as state in the sequent rpc call "get". + let sub = + client.subscribe::("sub", rpc_params!(), "unsub").with_default_timeout().await.unwrap().unwrap(); + + // assert that method call was answered + // and a few notification were sent by + // the spawned the task. + // + // ideally, that ordering should also be tested here + // but not possible to test properly. + assert!(client.request::("get", rpc_params!()).await.is_ok()); + assert_eq!(sub.count().await, 3); +} + /// Run shutdown test and it does: /// /// - Make 10 calls that sleeps for 20 seconds diff --git a/tests/tests/proc_macros.rs b/tests/tests/proc_macros.rs index 721b010e2f..a88d1f2e85 100644 --- a/tests/tests/proc_macros.rs +++ b/tests/tests/proc_macros.rs @@ -266,7 +266,7 @@ async fn macro_optional_param_parsing() { .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_optional_params","params":{"a":22,"c":50},"id":0}"#, 1) .await .unwrap(); - assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":"Called with: 22, None, Some(50)","id":0}"#); + assert_eq!(resp.into_result(), r#"{"jsonrpc":"2.0","result":"Called with: 22, None, Some(50)","id":0}"#); } #[tokio::test] @@ -290,14 +290,14 @@ async fn macro_zero_copy_cow() { .unwrap(); // std::borrow::Cow always deserialized to owned variant here - assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":"Zero copy params: false, true","id":0}"#); + assert_eq!(resp.into_result(), r#"{"jsonrpc":"2.0","result":"Zero copy params: false, true","id":0}"#); // serde_json will have to allocate a new string to replace `\t` with byte 0x09 (tab) let (resp, _) = module .raw_json_request(r#"{"jsonrpc":"2.0","method":"foo_zero_copy_cow","params":["\tfoo", "\tbar"],"id":0}"#, 1) .await .unwrap(); - assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":"Zero copy params: false, false","id":0}"#); + assert_eq!(resp.into_result(), r#"{"jsonrpc":"2.0","result":"Zero copy params: false, false","id":0}"#); } // Disabled on MacOS as GH CI timings on Mac vary wildly (~100ms) making this test fail. diff --git a/tests/tests/rpc_module.rs b/tests/tests/rpc_module.rs index 59752aaa7c..9e9e081271 100644 --- a/tests/tests/rpc_module.rs +++ b/tests/tests/rpc_module.rs @@ -38,11 +38,14 @@ use jsonrpsee::core::{server::*, RpcResult}; use jsonrpsee::types::error::{ErrorCode, ErrorObject, INVALID_PARAMS_MSG, PARSE_ERROR_CODE}; use jsonrpsee::types::{ErrorObjectOwned, Params, Response, ResponsePayload}; use jsonrpsee::SubscriptionMessage; +use jsonrpsee_test_utils::mocks::Id; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; +use crate::helpers::{rpc_module_notify_on_response, run_test_notify_test, Notify}; + // Helper macro to assert that a binding is of a specific type. macro_rules! assert_type { ( $ty:ty, $expected:expr $(,)?) => {{ @@ -384,13 +387,13 @@ async fn subscribe_unsubscribe_without_server() { let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id); let (resp, _) = module.raw_json_request(&unsub_req, 1).await.unwrap(); - assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":true,"id":1}"#); + assert_eq!(resp.into_result(), r#"{"jsonrpc":"2.0","result":true,"id":1}"#); // Unsubscribe already performed; should be error. let unsub_req = format!("{{\"jsonrpc\":\"2.0\",\"method\":\"my_unsub\",\"params\":[{}],\"id\":1}}", ser_id); let (resp, _) = module.raw_json_request(&unsub_req, 2).await.unwrap(); - assert_eq!(resp.result, r#"{"jsonrpc":"2.0","result":false,"id":1}"#); + assert_eq!(resp.into_result(), r#"{"jsonrpc":"2.0","result":false,"id":1}"#); } let sub1 = subscribe_and_assert(&module); @@ -430,7 +433,7 @@ async fn reject_works() { .unwrap(); let (rp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"my_sub","id":0}"#, 1).await.unwrap(); - assert_eq!(rp.result, r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"rejected"},"id":0}"#); + assert_eq!(rp.into_result(), r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"rejected"},"id":0}"#); assert!(stream.recv().await.is_none()); } @@ -521,11 +524,11 @@ async fn serialize_sub_error_adds_extra_string_quotes() { .unwrap(); let (rp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"my_sub","id":0}"#, 1).await.unwrap(); - let resp = serde_json::from_str::>(&rp.result).unwrap(); + let resp = serde_json::from_str::>(rp.as_result()).unwrap(); let sub_resp = stream.recv().await.unwrap(); let resp = match resp.payload { - ResponsePayload::Result(val) => val, + ResponsePayload::Success(val) => val, _ => panic!("Expected valid response"), }; @@ -566,10 +569,10 @@ async fn subscription_close_response_works() { { let (rp, mut stream) = module.raw_json_request(r#"{"jsonrpc":"2.0","method":"my_sub","params":[1],"id":0}"#, 1).await.unwrap(); - let resp = serde_json::from_str::>(&rp.result).unwrap(); + let resp = serde_json::from_str::>(rp.as_result()).unwrap(); let sub_id = match resp.payload { - ResponsePayload::Result(val) => val, + ResponsePayload::Success(val) => val, _ => panic!("Expected valid response"), }; @@ -586,3 +589,42 @@ async fn subscription_close_response_works() { assert_eq!(rx, 1); } } + +#[tokio::test] +async fn method_response_notify_on_completion() { + init_logger(); + + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let module = rpc_module_notify_on_response(tx); + + assert!( + run_test_notify_test(&module, &mut rx, true, Notify::Success).await.is_ok(), + "Successful response should be notified" + ); + assert!(matches!( + run_test_notify_test(&module, &mut rx, false, Notify::Success).await, + Err(MethodResponseError::JsonRpcError), + )); + + assert!(matches!( + run_test_notify_test(&module, &mut rx, false, Notify::Error).await, + Err(MethodResponseError::JsonRpcError), + )); +} + +#[tokio::test] +async fn method_response_dropped() { + init_logger(); + + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let module = rpc_module_notify_on_response(tx); + + let req = jsonrpsee_test_utils::helpers::call("hey", vec![Notify::Success], Id::Num(1)); + + // Make a call and drop the method response including its "notify sender" + // This could happen if the connection is closed. + let (rp, _) = module.raw_json_request(&req, 1).await.unwrap(); + drop(rp); + + assert!(matches!(rx.recv().await, Some(Err(MethodResponseError::Closed)))); +} diff --git a/types/src/lib.rs b/types/src/lib.rs index 758815ecbe..288c3fbbee 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -43,7 +43,7 @@ pub mod response; /// JSON-RPC response error object related types. pub mod error; -pub use error::{ErrorObject, ErrorObjectOwned}; +pub use error::{ErrorCode, ErrorObject, ErrorObjectOwned}; pub use params::{Id, InvalidRequestId, Params, ParamsSequence, SubscriptionId, TwoPointZero}; pub use request::{InvalidRequest, Notification, NotificationSer, Request, RequestSer}; pub use response::{Response, ResponsePayload, SubscriptionPayload, SubscriptionResponse, Success as ResponseSuccess}; diff --git a/types/src/response.rs b/types/src/response.rs index 3f8a53e6e1..9522231cd0 100644 --- a/types/src/response.rs +++ b/types/src/response.rs @@ -95,7 +95,7 @@ impl<'a, T: Clone> TryFrom> for Success<'a, T> { fn try_from(rp: Response<'a, T>) -> Result { match rp.payload { ResponsePayload::Error(e) => Err(e.into_owned()), - ResponsePayload::Result(r) => Ok(Success { jsonrpc: rp.jsonrpc, result: r.into_owned(), id: rp.id }), + ResponsePayload::Success(r) => Ok(Success { jsonrpc: rp.jsonrpc, result: r.into_owned(), id: rp.id }), } } } @@ -139,45 +139,43 @@ where T: Clone, { /// Corresponds to successful JSON-RPC response with the field `result`. - Result(StdCow<'a, T>), + Success(StdCow<'a, T>), /// Corresponds to failed JSON-RPC response with a error object with the field `error. Error(ErrorObject<'a>), } impl<'a, T: Clone> ResponsePayload<'a, T> { - /// Create successful an owned response payload. - pub fn result(t: T) -> Self { - Self::Result(StdCow::Owned(t)) + /// Create a successful owned response payload. + pub fn success(t: T) -> Self { + Self::Success(StdCow::Owned(t)) } - /// Create successful borrowed response payload. - pub fn result_borrowed(t: &'a T) -> Self { - Self::Result(StdCow::Borrowed(t)) + /// Create a successful borrowed response payload. + pub fn success_borrowed(t: &'a T) -> Self { + Self::Success(StdCow::Borrowed(t)) } /// Convert the response payload into owned. pub fn into_owned(self) -> ResponsePayload<'static, T> { match self { Self::Error(e) => ResponsePayload::Error(e.into_owned()), - Self::Result(r) => ResponsePayload::Result(StdCow::Owned(r.into_owned())), + Self::Success(r) => ResponsePayload::Success(StdCow::Owned(r.into_owned())), } } -} -impl<'a> ResponsePayload<'a, ()> { - /// Create successful partial response i.e, the `result field` + /// Create an error response payload. pub fn error(e: impl Into) -> Self { Self::Error(e.into()) } - /// Create successful partial response i.e, the `result field` + /// Create a borrowd error response payload. pub fn error_borrowed(e: impl Into>) -> Self { Self::Error(e.into()) } } -impl<'a> From for ResponsePayload<'a, ()> { - fn from(code: ErrorCode) -> Self { +impl<'a, T: Clone> From for ResponsePayload<'a, T> { + fn from(code: ErrorCode) -> ResponsePayload<'a, T> { Self::Error(code.into()) } } @@ -291,10 +289,12 @@ where return Err(serde::de::Error::duplicate_field("result and error are mutually exclusive")) } (Some(jsonrpc), Some(result), None) => { - Response { jsonrpc, payload: ResponsePayload::Result(result), id } + Response { jsonrpc, payload: ResponsePayload::Success(result), id } } (Some(jsonrpc), None, Some(err)) => Response { jsonrpc, payload: ResponsePayload::Error(err), id }, - (None, Some(result), _) => Response { jsonrpc: None, payload: ResponsePayload::Result(result), id }, + (None, Some(result), _) => { + Response { jsonrpc: None, payload: ResponsePayload::Success(result), id } + } (None, _, Some(err)) => Response { jsonrpc: None, payload: ResponsePayload::Error(err), id }, (_, None, None) => return Err(serde::de::Error::missing_field("result/error")), }; @@ -324,7 +324,7 @@ where match &self.payload { ResponsePayload::Error(err) => s.serialize_field("error", err)?, - ResponsePayload::Result(r) => s.serialize_field("result", r)?, + ResponsePayload::Success(r) => s.serialize_field("result", r)?, }; s.serialize_field("id", &self.id)?; @@ -341,7 +341,7 @@ mod tests { fn serialize_call_ok_response() { let ser = serde_json::to_string(&Response { jsonrpc: Some(TwoPointZero), - payload: ResponsePayload::result("ok"), + payload: ResponsePayload::success("ok"), id: Id::Number(1), }) .unwrap(); @@ -353,7 +353,7 @@ mod tests { fn serialize_call_err_response() { let ser = serde_json::to_string(&Response { jsonrpc: Some(TwoPointZero), - payload: ResponsePayload::error(ErrorObjectOwned::owned(1, "lo", None::<()>)), + payload: ResponsePayload::<()>::error(ErrorObjectOwned::owned(1, "lo", None::<()>)), id: Id::Number(1), }) .unwrap(); @@ -365,7 +365,7 @@ mod tests { fn serialize_call_response_missing_version_field() { let ser = serde_json::to_string(&Response { jsonrpc: None, - payload: ResponsePayload::result("ok"), + payload: ResponsePayload::success("ok"), id: Id::Number(1), }) .unwrap(); @@ -376,7 +376,7 @@ mod tests { #[test] fn deserialize_success_call() { let exp = - Response { jsonrpc: Some(TwoPointZero), payload: ResponsePayload::result(99_u64), id: Id::Number(11) }; + Response { jsonrpc: Some(TwoPointZero), payload: ResponsePayload::success(99_u64), id: Id::Number(11) }; let dsr: Response = serde_json::from_str(r#"{"jsonrpc":"2.0", "result":99, "id":11}"#).unwrap(); assert_eq!(dsr.jsonrpc, exp.jsonrpc); assert_eq!(dsr.payload, exp.payload); @@ -399,7 +399,7 @@ mod tests { #[test] fn deserialize_call_missing_version_field() { - let exp = Response { jsonrpc: None, payload: ResponsePayload::result(99_u64), id: Id::Number(11) }; + let exp = Response { jsonrpc: None, payload: ResponsePayload::success(99_u64), id: Id::Number(11) }; let dsr: Response = serde_json::from_str(r#"{"jsonrpc":null, "result":99, "id":11}"#).unwrap(); assert_eq!(dsr.jsonrpc, exp.jsonrpc); assert_eq!(dsr.payload, exp.payload);