Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify ApplyTcpOptionsError #60

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 35 additions & 51 deletions src/tcp_options.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(target_os = "linux")]
use nix::sys::socket::{getsockopt, setsockopt, sockopt};
use std::fmt;
use std::io;
use std::time::Duration;
use tokio::net::{TcpSocket, TcpStream};

Expand Down Expand Up @@ -37,15 +36,26 @@ pub struct TcpOptions {

/// Represents a failure to apply socket options to the TCP socket.
#[derive(Debug)]
pub struct ApplyTcpOptionsError(ApplyTcpOptionsErrorInternal);
pub struct ApplyTcpOptionsError {
kind: ApplyTcpOptionsErrorKind,
source: Box<dyn std::error::Error>,
}

#[derive(Debug)]
enum ApplyTcpOptionsErrorInternal {
RecvBuffer(io::Error),
SendBuffer(io::Error),
#[cfg(target_os = "linux")]
Mark(nix::Error),
TcpNoDelay(io::Error),
impl ApplyTcpOptionsError {
fn new<S>(kind: ApplyTcpOptionsErrorKind, source: S) -> Self
where
S: Into<Box<dyn std::error::Error>>,
{
Self {
kind,
source: source.into(),
}
}

/// Returns the kind of error that happened as an enum
pub fn kind(&self) -> ApplyTcpOptionsErrorKind {
self.kind
}
}

/// A list specifying what failed when applying the TCP options.
Expand All @@ -66,50 +76,23 @@ pub enum ApplyTcpOptionsErrorKind {
TcpNoDelay,
}

impl ApplyTcpOptionsError {
/// Returns the kind of error that happened as an enum
pub fn kind(&self) -> ApplyTcpOptionsErrorKind {
use ApplyTcpOptionsErrorInternal::*;
match self.0 {
RecvBuffer(_) => ApplyTcpOptionsErrorKind::RecvBuffer,
SendBuffer(_) => ApplyTcpOptionsErrorKind::SendBuffer,
#[cfg(target_os = "linux")]
Mark(_) => ApplyTcpOptionsErrorKind::Mark,
TcpNoDelay(_) => ApplyTcpOptionsErrorKind::TcpNoDelay,
}
}
}

impl From<ApplyTcpOptionsErrorInternal> for ApplyTcpOptionsError {
fn from(value: ApplyTcpOptionsErrorInternal) -> Self {
Self(value)
}
}

impl fmt::Display for ApplyTcpOptionsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use ApplyTcpOptionsErrorInternal::*;
match self.0 {
RecvBuffer(_) => "Failed to get/set TCP_RCVBUF",
SendBuffer(_) => "Failed to get/set TCP_SNDBUF",
use ApplyTcpOptionsErrorKind::*;
match self.kind {
RecvBuffer => "Failed to get/set TCP_RCVBUF",
SendBuffer => "Failed to get/set TCP_SNDBUF",
#[cfg(target_os = "linux")]
Mark(_) => "Failed to get/set SO_MARK",
TcpNoDelay(_) => "Failed to get/set TCP_NODELAY",
Mark => "Failed to get/set SO_MARK",
TcpNoDelay => "Failed to get/set TCP_NODELAY",
}
.fmt(f)
}
}

impl std::error::Error for ApplyTcpOptionsError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use ApplyTcpOptionsErrorInternal::*;
match &self.0 {
RecvBuffer(e) => Some(e),
SendBuffer(e) => Some(e),
#[cfg(target_os = "linux")]
Mark(e) => Some(e),
TcpNoDelay(e) => Some(e),
}
Some(self.source.as_ref())
}
}

Expand All @@ -124,34 +107,35 @@ pub fn apply(socket: &TcpSocket, options: &TcpOptions) -> Result<(), ApplyTcpOpt
if let Some(recv_buffer_size) = options.recv_buffer_size {
socket
.set_recv_buffer_size(recv_buffer_size)
.map_err(ApplyTcpOptionsErrorInternal::RecvBuffer)?;
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::RecvBuffer, e))?;
}
log::debug!(
"SO_RCVBUF: {}",
socket
.recv_buffer_size()
.map_err(ApplyTcpOptionsErrorInternal::RecvBuffer)?
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::RecvBuffer, e))?
);
if let Some(send_buffer_size) = options.send_buffer_size {
socket
.set_send_buffer_size(send_buffer_size)
.map_err(ApplyTcpOptionsErrorInternal::SendBuffer)?;
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::SendBuffer, e))?;
}
log::debug!(
"SO_SNDBUF: {}",
socket
.send_buffer_size()
.map_err(ApplyTcpOptionsErrorInternal::SendBuffer)?
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::SendBuffer, e))?
);
#[cfg(target_os = "linux")]
{
if let Some(fwmark) = options.fwmark {
setsockopt(&socket, sockopt::Mark, &fwmark)
.map_err(ApplyTcpOptionsErrorInternal::Mark)?;
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::Mark, e))?;
}
log::debug!(
"SO_MARK: {}",
getsockopt(&socket, sockopt::Mark).map_err(ApplyTcpOptionsErrorInternal::Mark)?
getsockopt(&socket, sockopt::Mark)
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::Mark, e))?
);
}
Ok(())
Expand All @@ -163,12 +147,12 @@ pub fn set_nodelay(tcp_stream: &TcpStream, nodelay: bool) -> Result<(), ApplyTcp
// Configure TCP_NODELAY on the TCP stream
tcp_stream
.set_nodelay(nodelay)
.map_err(ApplyTcpOptionsErrorInternal::TcpNoDelay)?;
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::TcpNoDelay, e))?;
log::debug!(
"TCP_NODELAY: {}",
tcp_stream
.nodelay()
.map_err(ApplyTcpOptionsErrorInternal::TcpNoDelay)?
.map_err(|e| ApplyTcpOptionsError::new(ApplyTcpOptionsErrorKind::TcpNoDelay, e))?
);
Ok(())
}