chore: allow limiting the size of the uncompressed payload
ctron committed Sep 17, 2024
1 parent ea60800 commit cfb9205
Showing 5 changed files with 226 additions and 20 deletions.
67 changes: 62 additions & 5 deletions common/src/compression/detecting.rs
@@ -18,19 +18,67 @@ pub enum Compression {
     Xz,
 }

+#[non_exhaustive]
+#[derive(Clone, Debug, PartialEq, Eq, Default)]
+pub struct DecompressionOptions {
+    /// The maximum decompressed payload size.
+    ///
+    /// If the size of the uncompressed payload exceeds this limit, an error is returned
+    /// instead. Zero means unlimited.
+    pub limit: usize,
+}
+
+impl DecompressionOptions {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the maximum uncompressed payload size.
+    pub fn limit(mut self, limit: usize) -> Self {
+        self.limit = limit;
+        self
+    }
+}
+
 impl Compression {
-    /// Detect and decompress in a single step.
+    /// Perform decompression.
     ///
     /// Returns the original data for [`Compression::None`].
     pub fn decompress(&self, data: Bytes) -> Result<Bytes, std::io::Error> {
         Ok(self.decompress_opt(&data)?.unwrap_or(data))
     }

-    /// Detect and decompress in a single step.
+    /// Perform decompression.
+    ///
+    /// Returns the original data for [`Compression::None`].
+    pub fn decompress_with(
+        &self,
+        data: Bytes,
+        opts: &DecompressionOptions,
+    ) -> Result<Bytes, std::io::Error> {
+        Ok(self.decompress_opt_with(&data, opts)?.unwrap_or(data))
+    }
+
+    /// Perform decompression.
     ///
     /// Returns `None` for [`Compression::None`]
     pub fn decompress_opt(&self, data: &[u8]) -> Result<Option<Bytes>, std::io::Error> {
+        self.decompress_opt_with(data, &Default::default())
+    }
+
+    /// Perform decompression.
+    ///
+    /// Returns `None` for [`Compression::None`]
+    pub fn decompress_opt_with(
+        &self,
+        data: &[u8],
+        opts: &DecompressionOptions,
+    ) -> Result<Option<Bytes>, std::io::Error> {
         match self {
             #[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
-            Compression::Bzip2 => super::decompress_bzip2(data).map(Some),
+            Compression::Bzip2 => super::decompress_bzip2_with(data, opts).map(Some),
             #[cfg(feature = "liblzma")]
-            Compression::Xz => super::decompress_xz(data).map(Some),
+            Compression::Xz => super::decompress_xz_with(data, opts).map(Some),
             Compression::None => Ok(None),
         }
     }
@@ -53,8 +101,17 @@ pub struct Detector<'a> {
 impl<'a> Detector<'a> {
     /// Detect and decompress in a single step.
     pub fn decompress(&'a self, data: Bytes) -> Result<Bytes, Error<'a>> {
+        self.decompress_with(data, &Default::default())
+    }
+
+    /// Detect and decompress in a single step.
+    pub fn decompress_with(
+        &'a self,
+        data: Bytes,
+        opts: &DecompressionOptions,
+    ) -> Result<Bytes, Error<'a>> {
         let compression = self.detect(&data)?;
-        Ok(compression.decompress(data)?)
+        Ok(compression.decompress_with(data, opts)?)
     }

     pub fn detect(&'a self, #[allow(unused)] data: &[u8]) -> Result<Compression, Error<'a>> {
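Taken together, the new API is opt-in: existing `decompress` calls keep their behavior, while `decompress_with` accepts the new options. A minimal consumer-side sketch (the `walker_common::compression` import path is taken from the test further down; the `data/example.bz2` fixture and `main` wrapper are hypothetical):

    use bytes::Bytes;
    use walker_common::compression::{Compression, DecompressionOptions};

    fn main() -> Result<(), std::io::Error> {
        // hypothetical bzip2-compressed input
        let compressed = Bytes::from_static(include_bytes!("data/example.bz2"));

        // cap the uncompressed payload at 1 MiB; anything larger fails with an
        // I/O error instead of exhausting memory
        let opts = DecompressionOptions::new().limit(1024 * 1024);
        let payload = Compression::Bzip2.decompress_with(compressed, &opts)?;

        println!("decompressed {} bytes", payload.len());
        Ok(())
    }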
95 changes: 95 additions & 0 deletions common/src/compression/limit.rs
@@ -0,0 +1,95 @@
use std::io::{Error, ErrorKind, Write};

/// A writer limiting the output, failing if more data is written than allowed.
pub struct LimitWriter<W>
where
    W: Write,
{
    writer: W,
    limit: usize,
    current: usize,
}

impl<W> LimitWriter<W>
where
    W: Write,
{
    /// Create a new writer, providing the limit.
    pub fn new(writer: W, limit: usize) -> Self {
        Self {
            writer,
            limit,
            current: 0,
        }
    }

    /// Close the writer, returning the inner writer.
    ///
    /// Note: Closing the writer will not flush it first.
    pub fn close(self) -> W {
        self.writer
    }
}

impl<W> Write for LimitWriter<W>
where
    W: Write,
{
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // check what is remaining
        let remaining = self.limit.saturating_sub(self.current);
        // if nothing is left ...
        if remaining == 0 {
            // ... return an error
            return Err(Error::new(ErrorKind::WriteZero, "write limit exceeded"));
        }

        // write out the data, capped to the remaining limit
        let to_write = remaining.min(buf.len());
        let bytes_written = self.writer.write(&buf[..to_write])?;
        self.current += bytes_written;

        Ok(bytes_written)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.writer.flush()
    }
}

#[cfg(test)]
mod test {
    use crate::compression::LimitWriter;
    use std::io::{Cursor, Write};

    fn perform_write(data: &[u8], limit: usize) -> Result<Vec<u8>, std::io::Error> {
        let mut out = LimitWriter::new(vec![], limit);
        std::io::copy(&mut Cursor::new(data), &mut out)?;
        out.flush()?;

        Ok(out.close())
    }

    #[test]
    fn write_ok() {
        assert!(matches!(
            perform_write(b"0123456789", 100).as_deref(),
            Ok(b"0123456789")
        ));
        assert!(matches!(perform_write(b"", 100).as_deref(), Ok(b"")));
        assert!(matches!(
            perform_write(b"0123456789", 10).as_deref(),
            Ok(b"0123456789")
        ));
        assert!(matches!(
            perform_write(b"012345678", 10).as_deref(),
            Ok(b"012345678")
        ));
    }

    #[test]
    fn write_err() {
        assert!(perform_write(b"01234567890", 10).is_err());
        assert!(perform_write(b"012345678901", 10).is_err());
    }
}
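One subtlety of the `write` implementation above: a single write that crosses the limit is truncated to the remaining capacity rather than rejected outright, and only the subsequent write fails with `ErrorKind::WriteZero`. A small sketch of that boundary behavior (import path assumed as in the tests):

    use std::io::Write;
    use walker_common::compression::LimitWriter;

    fn main() -> std::io::Result<()> {
        let mut writer = LimitWriter::new(Vec::new(), 4);

        // a write crossing the limit is truncated: only 4 of the 6 bytes are
        // accepted, and the shortened count is reported
        assert_eq!(writer.write(b"abcdef")?, 4);

        // once the limit is reached, the next write fails
        assert!(writer.write(b"g").is_err());

        assert_eq!(writer.close(), b"abcd".to_vec());
        Ok(())
    }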
65 changes: 50 additions & 15 deletions common/src/compression/mod.rs
@@ -1,12 +1,16 @@
 //! Helpers for using compression/decompression.

 mod detecting;
+mod limit;

 pub use detecting::*;
+pub use limit::*;

 use anyhow::anyhow;
 use bytes::Bytes;
+use std::io::Write;

-/// Decompress a bz2 stream, or fail if no encoder was configured.
+/// Decompress a stream, or fail if no encoder was configured.
 ///
 /// This function will not consume the data, but return `None`, if no decompression was required.
 /// This allows one to hold on to the original, compressed, data if necessary.
@@ -22,42 +26,73 @@ pub fn decompress_opt(data: &[u8], name: &str) -> Option<Result<Bytes, anyhow::Error>> {
     .transpose()
 }

-/// Decompress a bz2 stream, or fail if no encoder was configured.
+/// Decompress a stream, or fail if no encoder was configured.
 pub fn decompress(data: Bytes, name: &str) -> Result<Bytes, anyhow::Error> {
     decompress_opt(&data, name).unwrap_or_else(|| Ok(data))
 }

 /// Decompress bz2 using `bzip2-rs` (pure Rust version)
 #[cfg(all(feature = "bzip2-rs", not(feature = "bzip2")))]
+#[deprecated(since = "0.9.3", note = "Use Compression::decompress instead")]
 pub fn decompress_bzip2(data: &[u8]) -> Result<Bytes, std::io::Error> {
-    use std::io::Read;
+    decompress_bzip2_with(data, &DecompressionOptions::default())
+}
+
+/// Decompress bz2 using `bzip2-rs` (pure Rust version)
+#[cfg(all(feature = "bzip2-rs", not(feature = "bzip2")))]
+#[deprecated(since = "0.9.3", note = "Use Compression::decompress instead")]
+pub fn decompress_bzip2_with(
+    data: &[u8],
+    opts: &DecompressionOptions,
+) -> Result<Bytes, std::io::Error> {
     let mut decoder = bzip2_rs::DecoderReader::new(data);
-    let mut data = vec![];
-    decoder.read_to_end(&mut data)?;
-    Ok(Bytes::from(data))
+    decompress_limit(decoder, opts.limit)
 }

 /// Decompress bz2 using `bzip2` (based on `libbz2`).
 #[cfg(feature = "bzip2")]
+#[deprecated(since = "0.9.3", note = "Use Compression::decompress instead")]
 pub fn decompress_bzip2(data: &[u8]) -> Result<Bytes, std::io::Error> {
-    use std::io::Read;
-
-    let mut decoder = bzip2::read::BzDecoder::new(data);
-    let mut data = vec![];
-    decoder.read_to_end(&mut data)?;
-
-    Ok(Bytes::from(data))
+    decompress_bzip2_with(data, &DecompressionOptions::default())
+}
+
+/// Decompress bz2 using `bzip2` (based on `libbz2`).
+#[cfg(feature = "bzip2")]
+fn decompress_bzip2_with(
+    data: &[u8],
+    opts: &DecompressionOptions,
+) -> Result<Bytes, std::io::Error> {
+    let decoder = bzip2::read::BzDecoder::new(data);
+    decompress_limit(decoder, opts.limit)
+}

 /// Decompress xz using `liblzma`.
 #[cfg(feature = "liblzma")]
+#[deprecated(since = "0.9.3", note = "Use Compression::decompress instead")]
 pub fn decompress_xz(data: &[u8]) -> Result<Bytes, std::io::Error> {
-    use std::io::Read;
-
-    let mut decoder = liblzma::read::XzDecoder::new(data);
+    decompress_xz_with(data, &Default::default())
+}
+
+/// Decompress xz using `liblzma`.
+#[cfg(feature = "liblzma")]
+fn decompress_xz_with(data: &[u8], opts: &DecompressionOptions) -> Result<Bytes, std::io::Error> {
+    let decoder = liblzma::read::XzDecoder::new(data);
+    decompress_limit(decoder, opts.limit)
+}
+
+/// Decompress with an uncompressed payload limit.
+fn decompress_limit(mut reader: impl std::io::Read, limit: usize) -> Result<Bytes, std::io::Error> {
     let mut data = vec![];
-    decoder.read_to_end(&mut data)?;
+
+    let data = if limit > 0 {
+        let mut writer = LimitWriter::new(data, limit);
+        std::io::copy(&mut reader, &mut writer)?;
+        writer.flush()?;
+        writer.close()
+    } else {
+        reader.read_to_end(&mut data)?;
+        data
+    };

     Ok(Bytes::from(data))
 }
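Since `limit == 0` falls back to the plain `read_to_end` path, the deprecated free functions, which pass default options, keep the old unlimited behavior. Callers that opt into a limit can match on the error kind produced by `LimitWriter`; a sketch, assuming a hypothetical `data/example.xz` fixture:

    use bytes::Bytes;
    use walker_common::compression::{Compression, DecompressionOptions};

    fn main() {
        // hypothetical xz-compressed input
        let data = Bytes::from_static(include_bytes!("data/example.xz"));

        let result = Compression::Xz
            .decompress_with(data, &DecompressionOptions::new().limit(1024));

        match result {
            Ok(payload) => println!("{} bytes", payload.len()),
            Err(e) if e.kind() == std::io::ErrorKind::WriteZero => {
                eprintln!("uncompressed payload exceeds the limit: {e}")
            }
            Err(e) => eprintln!("decompression failed: {e}"),
        }
    }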
Binary file added common/tests/data/bomb.bz2
19 changes: 19 additions & 0 deletions common/tests/limit.rs
@@ -0,0 +1,19 @@
use bytes::Bytes;
use walker_common::compression::{Compression, DecompressionOptions};

/// Test the case of having an unreasonably large decompressed size.
///
/// The idea is to have a compressed file which, by itself, has an acceptable size, but which
/// decompresses into an unreasonably large payload. This should be prevented by applying a limit
/// to the decompression.
#[test]
#[cfg(any(feature = "bzip2", feature = "bzip2-rs"))]
fn bzip2_bomb() {
    let data = include_bytes!("data/bomb.bz2");
    let result = Compression::Bzip2.decompress_with(
        Bytes::from_static(data),
        &DecompressionOptions::new().limit(1024 * 1024),
    );

    assert!(result.is_err());
}
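The `bomb.bz2` fixture itself is binary and not shown. For illustration only (not part of this commit, and assuming the `bzip2` crate is available), such a fixture could be generated by compressing a long run of zeros, which shrinks to a few kilobytes but expands to a gigabyte:

    use bzip2::{write::BzEncoder, Compression};
    use std::io::Write;

    fn main() -> std::io::Result<()> {
        let file = std::fs::File::create("tests/data/bomb.bz2")?;
        let mut encoder = BzEncoder::new(file, Compression::best());

        // 1 GiB of zeros, written in 1 MiB chunks
        let chunk = vec![0u8; 1024 * 1024];
        for _ in 0..1024 {
            encoder.write_all(&chunk)?;
        }
        encoder.try_finish()?;
        Ok(())
    }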
