diff --git a/ul/src/association/client.rs b/ul/src/association/client.rs index ffb7af86a..06a41b979 100644 --- a/ul/src/association/client.rs +++ b/ul/src/association/client.rs @@ -8,7 +8,7 @@ use std::{ borrow::Cow, convert::TryInto, io::Write, - net::{TcpStream, ToSocketAddrs}, + net::{TcpStream, ToSocketAddrs}, time::Duration, }; use crate::{ @@ -39,6 +39,18 @@ pub enum Error { source: std::io::Error, backtrace: Backtrace, }, + + /// Could not set tcp read timeout + SetReadTimeout{ + source: std::io::Error, + backtrace: Backtrace, + }, + + /// Could not set tcp write timeout + SetWriteTimeout{ + source: std::io::Error, + backtrace: Backtrace, + }, /// failed to send association request SendRequest { @@ -177,6 +189,10 @@ pub struct ClientAssociationOptions<'a> { saml_assertion: Option>, /// User identity JWT jwt: Option>, + /// TCP read timeout + read_timeout: Option, + /// TCP write timeout + write_timeout: Option, } impl<'a> Default for ClientAssociationOptions<'a> { @@ -198,6 +214,8 @@ impl<'a> Default for ClientAssociationOptions<'a> { kerberos_service_ticket: None, saml_assertion: None, jwt: None, + read_timeout: None, + write_timeout: None, } } } @@ -431,6 +449,22 @@ impl<'a> ClientAssociationOptions<'a> { } } + /// Set the read timeout for the underlying TCP socket + pub fn read_timeout(self, timeout: Duration) -> Self { + Self { + read_timeout: Some(timeout), + ..self + } + } + + /// Set the write timeout for the underlying TCP socket + pub fn write_timeout(self, timeout: Duration) -> Self { + Self { + write_timeout: Some(timeout), + ..self + } + } + fn establish_impl(self, ae_address: AeAddr) -> Result where T: ToSocketAddrs, @@ -448,6 +482,8 @@ impl<'a> ClientAssociationOptions<'a> { kerberos_service_ticket, saml_assertion, jwt, + read_timeout, + write_timeout } = self; // fail if no presentation contexts were provided: they represent intent, @@ -510,7 +546,12 @@ impl<'a> ClientAssociationOptions<'a> { user_variables, }); - let mut socket = std::net::TcpStream::connect(ae_address).context(ConnectSnafu)?; + let mut socket = std::net::TcpStream::connect(ae_address) + .context(ConnectSnafu)?; + socket.set_read_timeout(read_timeout) + .context(SetReadTimeoutSnafu)?; + socket.set_write_timeout(write_timeout) + .context(SetWriteTimeoutSnafu)?; let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); // send request