diff --git a/client/network/src/protocol/notifications/handler.rs b/client/network/src/protocol/notifications/handler.rs index 57561c7b9879d..ca87941cb96df 100644 --- a/client/network/src/protocol/notifications/handler.rs +++ b/client/network/src/protocol/notifications/handler.rs @@ -782,6 +782,9 @@ impl ConnectionHandler for NotifsHandler { // performed before the code paths that can produce `Ready` (with some rare exceptions). // Importantly, however, the flush is performed *after* notifications are queued with // `Sink::start_send`. + // Note that we must call `poll_flush` on all substreams and not only on those we + // have called `Sink::start_send` on, because `NotificationsOutSubstream::poll_flush` + // also reports the substream termination (even if no data was written into it). for protocol_index in 0..self.protocols.len() { match &mut self.protocols[protocol_index].state { State::Open { out_substream: out_substream @ Some(_), .. } => { @@ -824,7 +827,7 @@ impl ConnectionHandler for NotifsHandler { State::OpenDesiredByRemote { in_substream, pending_opening } => match NotificationsInSubstream::poll_process(Pin::new(in_substream), cx) { Poll::Pending => {}, - Poll::Ready(Ok(void)) => match void {}, + Poll::Ready(Ok(())) => {}, Poll::Ready(Err(_)) => { self.protocols[protocol_index].state = State::Closed { pending_opening: *pending_opening }; @@ -840,7 +843,7 @@ impl ConnectionHandler for NotifsHandler { cx, ) { Poll::Pending => {}, - Poll::Ready(Ok(void)) => match void {}, + Poll::Ready(Ok(())) => {}, Poll::Ready(Err(_)) => *in_substream = None, }, } diff --git a/client/network/src/protocol/notifications/upgrade/notifications.rs b/client/network/src/protocol/notifications/upgrade/notifications.rs index 71afc3c90e37f..3621c63497d95 100644 --- a/client/network/src/protocol/notifications/upgrade/notifications.rs +++ b/client/network/src/protocol/notifications/upgrade/notifications.rs @@ -41,7 +41,6 @@ use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use log::{error, warn}; use sc_network_common::protocol::ProtocolName; use std::{ - convert::Infallible, io, mem, pin::Pin, task::{Context, Poll}, @@ -221,10 +220,7 @@ where /// Equivalent to `Stream::poll_next`, except that it only drives the handshake and is /// guaranteed to not generate any notification. - pub fn poll_process( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll> { + pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let mut this = self.project(); loop { @@ -246,8 +242,10 @@ where }, NotificationsInSubstreamHandshake::Flush => { match Sink::poll_flush(this.socket.as_mut(), cx)? { - Poll::Ready(()) => - *this.handshake = NotificationsInSubstreamHandshake::Sent, + Poll::Ready(()) => { + *this.handshake = NotificationsInSubstreamHandshake::Sent; + return Poll::Ready(Ok(())) + }, Poll::Pending => { *this.handshake = NotificationsInSubstreamHandshake::Flush; return Poll::Pending @@ -260,7 +258,7 @@ where st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote | st @ NotificationsInSubstreamHandshake::BothSidesClosed => { *this.handshake = st; - return Poll::Pending + return Poll::Ready(Ok(())) }, } } @@ -443,6 +441,21 @@ where fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let mut this = self.project(); + + // `Sink::poll_flush` does not expose stream closed error until we write something into + // the stream, so the code below makes sure we detect that the substream was closed + // even if we don't write anything into it. + match Stream::poll_next(this.socket.as_mut(), cx) { + Poll::Pending => {}, + Poll::Ready(Some(_)) => { + error!( + target: "sub-libp2p", + "Unexpected incoming data in `NotificationsOutSubstream`", + ); + }, + Poll::Ready(None) => return Poll::Ready(Err(NotificationsOutError::Terminated)), + } + Sink::poll_flush(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io) } @@ -492,13 +505,21 @@ pub enum NotificationsOutError { /// I/O error on the substream. #[error(transparent)] Io(#[from] io::Error), + + /// End of incoming data detected on out substream. + #[error("substream was closed/reset")] + Terminated, } #[cfg(test)] mod tests { - use super::{NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutOpen}; - use futures::{channel::oneshot, prelude::*}; + use super::{ + NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutError, + NotificationsOutOpen, + }; + use futures::{channel::oneshot, future, prelude::*}; use libp2p::core::upgrade; + use std::{pin::Pin, task::Poll}; use tokio::net::{TcpListener, TcpStream}; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -691,4 +712,95 @@ mod tests { client.await.unwrap(); } + + #[tokio::test] + async fn send_handshake_without_polling_for_incoming_data() { + const PROTO_NAME: &str = "/test/proto/1"; + let (listener_addr_tx, listener_addr_rx) = oneshot::channel(); + + let client = tokio::spawn(async move { + let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap(); + let NotificationsOutOpen { handshake, .. } = upgrade::apply_outbound( + socket.compat(), + NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024), + upgrade::Version::V1, + ) + .await + .unwrap(); + + assert_eq!(handshake, b"hello world"); + }); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + listener_addr_tx.send(listener.local_addr().unwrap()).unwrap(); + + let (socket, _) = listener.accept().await.unwrap(); + let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound( + socket.compat(), + NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024), + ) + .await + .unwrap(); + + assert_eq!(handshake, b"initial message"); + substream.send_handshake(&b"hello world"[..]); + + // Actually send the handshake. + future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap(); + + client.await.unwrap(); + } + + #[tokio::test] + async fn can_detect_dropped_out_substream_without_writing_data() { + const PROTO_NAME: &str = "/test/proto/1"; + let (listener_addr_tx, listener_addr_rx) = oneshot::channel(); + + let client = tokio::spawn(async move { + let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap(); + let NotificationsOutOpen { handshake, mut substream, .. } = upgrade::apply_outbound( + socket.compat(), + NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024), + upgrade::Version::V1, + ) + .await + .unwrap(); + + assert_eq!(handshake, b"hello world"); + + future::poll_fn(|cx| match Pin::new(&mut substream).poll_flush(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => { + cx.waker().wake_by_ref(); + Poll::Pending + }, + Poll::Ready(Err(e)) => { + assert!(matches!(e, NotificationsOutError::Terminated)); + Poll::Ready(()) + }, + }) + .await; + }); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + listener_addr_tx.send(listener.local_addr().unwrap()).unwrap(); + + let (socket, _) = listener.accept().await.unwrap(); + let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound( + socket.compat(), + NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024), + ) + .await + .unwrap(); + + assert_eq!(handshake, b"initial message"); + + // Send the handhsake. + substream.send_handshake(&b"hello world"[..]); + future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap(); + + drop(substream); + + client.await.unwrap(); + } }