Skip to content

Commit

Permalink
net: Allow specifying custom interest for TcpStream.
Browse files Browse the repository at this point in the history
With the addition of Interest::PRIORITY, it is possible to ask `ready()`
to wake on urgent data. However, because `PollEvented` is set up with
only read/write interest, `ready()`'s future will never complete.

Add support for specifying custom interest in the creation of
`PollEvented` to allow `ready()` to correctly poll for urgent.

Fixes: #5784
  • Loading branch information
leftmostcat committed Jun 15, 2023
1 parent 00af6ef commit 6d5a255
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 7 deletions.
7 changes: 7 additions & 0 deletions tokio/src/io/interest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ impl Interest {
///
/// assert!(BOTH.is_readable());
/// assert!(BOTH.is_writable());
/// ```
pub const fn add(self, other: Interest) -> Interest {
Interest(self.0.add(other.0))
}
Expand All @@ -135,6 +136,12 @@ impl Interest {
}
}

impl Default for Interest {
fn default() -> Self {
Interest::READABLE.add(Interest::WRITABLE)
}
}

impl ops::BitOr for Interest {
type Output = Self;

Expand Down
47 changes: 46 additions & 1 deletion tokio/src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,58 @@ impl TcpListener {
/// }
/// ```
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
self.accept_with_interest(Default::default()).await
}

/// Accepts a new incoming connection from this listener with custom
/// interest registration.
///
/// This function will yield once a new TCP connection is established. When
/// established, the corresponding [`TcpStream`] and the remote peer's
/// address will be returned.
///
/// # Cancel safety
///
/// This method is cancel safe. If the method is used as the event in a
/// [`tokio::select!`](crate::select) statement and some other branch
/// completes first, then it is guaranteed that no new connections were
/// accepted by this method.
///
/// [`TcpStream`]: struct@crate::net::TcpStream
///
/// # Examples
///
/// ```no_run
/// use tokio::{io::Interest, net::TcpListener};
///
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// match listener
/// .accept_with_interest(Interest::PRIORITY.add(Default::default()))
/// .await
/// {
/// Ok((_socket, addr)) => println!("new client: {:?}", addr),
/// Err(e) => println!("couldn't get client: {:?}", e),
/// }
///
/// Ok(())
/// }
/// ```
pub async fn accept_with_interest(
&self,
interest: Interest,
) -> io::Result<(TcpStream, SocketAddr)> {
let (mio, addr) = self
.io
.registration()
.async_io(Interest::READABLE, || self.io.accept())
.await?;

let stream = TcpStream::new(mio)?;
let stream = TcpStream::new_with_interest(mio, interest)?;
Ok((stream, addr))
}

Expand Down
2 changes: 1 addition & 1 deletion tokio/src/net/tcp/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ impl TcpSocket {
unsafe { mio::net::TcpStream::from_raw_socket(raw_socket) }
};

TcpStream::connect_mio(mio).await
TcpStream::connect_mio(mio, Default::default()).await
}

/// Converts the socket into a `TcpListener`.
Expand Down
65 changes: 60 additions & 5 deletions tokio/src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,59 @@ impl TcpStream {
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
Self::connect_with_interest(addr, Default::default()).await
}

/// Opens a TCP connection to a remote host with custom interest
/// registration..
///
/// `addr` is an address of the remote host. Anything which implements the
/// [`ToSocketAddrs`] trait can be supplied as the address. If `addr`
/// yields multiple addresses, connect will be attempted with each of the
/// addresses until a connection is successful. If none of the addresses
/// result in a successful connection, the error returned from the last
/// connection attempt (the last address) is returned.
///
/// To configure the socket before connecting, you can use the [`TcpSocket`]
/// type.
///
/// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs
/// [`TcpSocket`]: struct@crate::net::TcpSocket
///
/// # Examples
///
/// ```no_run
/// use tokio::net::TcpStream;
/// use tokio::io::{AsyncWriteExt, Interest};
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// // Connect to a peer
/// let mut stream = TcpStream::connect_with_interest(
/// "127.0.0.1:8080",
/// Interest::PRIORITY.add(Default::default()),
/// )
/// .await?;
///
/// // Write some data.
/// stream.write_all(b"hello world!").await?;
///
/// Ok(())
/// }
/// ```
///
/// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait.
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn connect_with_interest<A: ToSocketAddrs>(addr: A, interest: Interest) -> io::Result<TcpStream> {
let addrs = to_socket_addrs(addr).await?;

let mut last_err = None;

for addr in addrs {
match TcpStream::connect_addr(addr).await {
match TcpStream::connect_addr(addr, interest).await {
Ok(stream) => return Ok(stream),
Err(e) => last_err = Some(e),
}
Expand All @@ -132,13 +179,13 @@ impl TcpStream {
}

/// Establishes a connection to the specified `addr`.
async fn connect_addr(addr: SocketAddr) -> io::Result<TcpStream> {
async fn connect_addr(addr: SocketAddr, interest: Interest) -> io::Result<TcpStream> {
let sys = mio::net::TcpStream::connect(addr)?;
TcpStream::connect_mio(sys).await
TcpStream::connect_mio(sys, interest).await
}

pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result<TcpStream> {
let stream = TcpStream::new(sys)?;
pub(crate) async fn connect_mio(sys: mio::net::TcpStream, interest: Interest) -> io::Result<TcpStream> {
let stream = TcpStream::new_with_interest(sys, interest)?;

// Once we've connected, wait for the stream to be writable as
// that's when the actual connection has been initiated. Once we're
Expand All @@ -161,6 +208,14 @@ impl TcpStream {
Ok(TcpStream { io })
}

pub(crate) fn new_with_interest(
connected: mio::net::TcpStream,
interest: Interest,
) -> io::Result<TcpStream> {
let io = PollEvented::new_with_interest(connected, interest)?;
Ok(TcpStream { io })
}

/// Creates new `TcpStream` from a `std::net::TcpStream`.
///
/// This function is intended to be used to wrap a TCP stream from the
Expand Down

0 comments on commit 6d5a255

Please sign in to comment.