diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d4aef7..8123da9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.0](https://github.com/spiraldb/fsst/compare/v0.2.3...v0.3.0) - 2024-09-03 + +### Added +- port in more from the C++ code ([#24](https://github.com/spiraldb/fsst/pull/24)) + +### Other +- centering ([#26](https://github.com/spiraldb/fsst/pull/26)) + ## [0.2.3](https://github.com/spiraldb/fsst/compare/v0.2.2...v0.2.3) - 2024-08-22 ### Added diff --git a/Cargo.lock b/Cargo.lock index 774db39..a24b1d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,6 +41,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "cc" +version = "1.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d2eb3cd3d1bf4529e31c215ee6f93ec5a3d536d9f578f93d9d33ee19562932" +dependencies = [ + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -166,6 +175,36 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "curl" +version = "0.4.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e2161dd6eba090ff1594084e95fd67aeccf04382ffea77999ea94ed42ec67b6" +dependencies = [ + "curl-sys", + "libc", + "openssl-probe", + "openssl-sys", + "schannel", + "socket2", + "windows-sys 0.52.0", +] + +[[package]] +name = "curl-sys" +version = "0.4.74+curl-8.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8af10b986114528fcdc4b63b6f5f021b7057618411046a4de2ba0f0149a097bf" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", + "windows-sys 0.52.0", +] + [[package]] name = "either" 
version = "1.13.0" @@ -174,9 +213,10 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "fsst-rs" -version = "0.2.3" +version = "0.3.0" dependencies = [ "criterion", + "curl", ] [[package]] @@ -236,6 +276,18 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "libz-sys" +version = "1.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc53a7799a7496ebc9fd29f31f7df80e83c9bda5299768af5f9e59eeea74647" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "log" version = "0.4.22" @@ -269,6 +321,30 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + [[package]] name = "plotters" version = "0.3.6" @@ -379,6 +455,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "serde" version = "1.0.206" @@ -411,6 +496,22 
@@ dependencies = [ "serde", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "socket2" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "syn" version = "2.0.74" @@ -438,6 +539,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "walkdir" version = "2.5.0" diff --git a/Cargo.toml b/Cargo.toml index e16235f..92d1c63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsst-rs" -version = "0.2.3" +version = "0.3.0" description = "Pure-Rust implementation of Fast Static Symbol Tables algorithm for string compression" authors = ["SpiralDB Developers "] license = "Apache-2.0" @@ -27,6 +27,7 @@ use_debug = { level = "deny" } [dev-dependencies] criterion = "0.5" +curl = "0.4" [[example]] name = "round_trip" @@ -37,6 +38,10 @@ test = false name = "compress" harness = false +[[bench]] +name = "micro" +harness = false + [[test]] name = "correctness" test = true diff --git a/README.md b/README.md index 9570808..6daae15 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ - +

+ +

![Crates.io Version](https://img.shields.io/crates/v/fsst_rs) ![docs.rs](https://img.shields.io/docsrs/fsst-rs) diff --git a/benches/.gitignore b/benches/.gitignore new file mode 100644 index 0000000..8fce603 --- /dev/null +++ b/benches/.gitignore @@ -0,0 +1 @@ +data/ diff --git a/benches/compress.rs b/benches/compress.rs index c9ff5af..8a26e50 100644 --- a/benches/compress.rs +++ b/benches/compress.rs @@ -1,56 +1,122 @@ //! Benchmarks for FSST compression, decompression, and symbol table training. +//! +//! We use the dbtext data at https://github.com/cwida/fsst/tree/master/paper/dbtext #![allow(missing_docs)] use core::str; +use std::{ + error::Error, + fs::{self, DirBuilder, File}, + io::{Read, Write}, + path::Path, +}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use fsst::{Compressor, ESCAPE_CODE}; +use curl::easy::Easy; +use fsst::Compressor; -const CORPUS: &str = include_str!("dracula.txt"); -const TEST: &str = "I found my smattering of German very useful here"; +fn download_dataset(url: &str, path: impl AsRef) -> Result<(), Box> { + let target = path.as_ref(); -fn bench_fsst(c: &mut Criterion) { - let mut group = c.benchmark_group("fsst"); - group.bench_function("train", |b| { - let corpus = CORPUS.as_bytes(); - b.iter(|| black_box(Compressor::train(black_box(corpus)))); - }); + let mut dir_builder = DirBuilder::new(); + dir_builder.recursive(true); - let compressor = Compressor::train(CORPUS); - let plaintext = TEST.as_bytes(); + dir_builder.create(target.parent().unwrap())?; - let compressed = compressor.compress(plaintext); - let escape_count = compressed.iter().filter(|b| **b == ESCAPE_CODE).count(); - let ratio = (plaintext.len() as f64) / (compressed.len() as f64); - println!( - "Escapes = {escape_count}/{}, compression_ratio = {ratio}", - compressed.len() + // Avoid downloading the file twice. 
+ if target.exists() { + return Ok(()); + } + + let mut handle = Easy::new(); + + let mut buffer = Vec::new(); + handle.url(url)?; + { + let mut transfer = handle.transfer(); + transfer.write_function(|data| { + buffer.extend_from_slice(data); + + Ok(data.len()) + })?; + transfer.perform()?; + } + + let mut output = File::create(target)?; + match output.write_all(&buffer) { + Ok(()) => {} + Err(err) => { + // cleanup in case of failure + fs::remove_file(target).unwrap(); + + return Err(Box::new(err)); + } + } + + Ok(()) +} + +#[allow(clippy::use_debug)] +fn bench_dbtext(c: &mut Criterion) { + fn run_dataset_bench(name: &str, url: &str, path: &str, c: &mut Criterion) { + let mut group = c.benchmark_group(name); + download_dataset(url, path).unwrap(); + + let mut buf = Vec::new(); + { + let mut file = File::open(path).unwrap(); + file.read_to_end(&mut buf).unwrap(); + } + + group.bench_function("train-and-compress", |b| { + b.iter_with_large_drop(|| { + let compressor = Compressor::train(&vec![&buf]); + compressor.compress_bulk(std::hint::black_box(&vec![&buf])) + }); + }); + + let compressor = Compressor::train(&vec![&buf]); + let mut buffer = Vec::with_capacity(200 * 1024 * 1024); + group.throughput(Throughput::Bytes(buf.len() as u64)); + group.bench_function("compress-only", |b| { + b.iter(|| unsafe { compressor.compress_into(&buf, &mut buffer) }); + }); + + group.finish(); + + // Report the compression factor for this dataset. 
+ let uncompressed_size = buf.len(); + let compressor = Compressor::train(&vec![&buf]); + + let compressed = compressor.compress_bulk(&vec![&buf]); + let compressed_size = compressed.iter().map(|l| l.len()).sum::(); + let cf = (uncompressed_size as f64) / (compressed_size as f64); + println!( + "compressed {name} {uncompressed_size} => {compressed_size}B (compression factor {cf:.2}:1)" + ) + } + + run_dataset_bench( + "dbtext/wikipedia", + "https://raw.githubusercontent.com/cwida/fsst/4e188a/paper/dbtext/wikipedia", + "benches/data/wikipedia", + c, + ); + + run_dataset_bench( + "dbtext/l_comment", + "https://raw.githubusercontent.com/cwida/fsst/4e188a/paper/dbtext/l_comment", + "benches/data/l_comment", + c, ); - let decompressor = compressor.decompressor(); - let decompressed = decompressor.decompress(&compressed); - let decompressed = str::from_utf8(&decompressed).unwrap(); - - group.throughput(Throughput::Elements(1)); - group.bench_function("compress-word", |b| { - let mut out = vec![0u8; 8]; - let out_ptr = out.as_mut_ptr(); - let front = &TEST.as_bytes()[0..8]; - let word = u64::from_le_bytes(front.try_into().unwrap()); - - b.iter(|| black_box(unsafe { compressor.compress_word(word, out_ptr) })); - }); - - group.throughput(Throughput::Bytes(CORPUS.len() as u64)); - group.bench_function("compress-single", |b| { - b.iter(|| black_box(compressor.compress(black_box(CORPUS.as_bytes())))); - }); - - group.throughput(Throughput::Bytes(decompressed.len() as u64)); - group.bench_function("decompress-single", |b| { - b.iter(|| black_box(decompressor.decompress(black_box(&compressed)))); - }); + run_dataset_bench( + "dbtext/urls", + "https://raw.githubusercontent.com/cwida/fsst/4e188a/paper/dbtext/urls", + "benches/data/urls", + c, + ); } -criterion_group!(compress_bench, bench_fsst); +criterion_group!(compress_bench, bench_dbtext); criterion_main!(compress_bench); diff --git a/benches/dracula.txt b/benches/dracula.txt deleted file mode 100644 index 88adb22..0000000 --- 
a/benches/dracula.txt +++ /dev/null @@ -1 +0,0 @@ -How these papers have been placed in sequence will be made manifest in the reading of them. All needless matters have been eliminated, so that a history almost at variance with the possibilities of later-day belief may stand forth as simple fact. There is throughout no statement of past things wherein memory may err, for all the records chosen are exactly contemporary, given from the standpoints and within the range of knowledge of those who made them. We left in pretty good time, and came after nightfall to Klausenburgh. Here I stopped for the night at the Hotel Royale. I had for dinner, or rather supper, a chicken done up some way with red pepper, which was very good but thirsty. (Mem., get recipe for Mina.) I asked the waiter, and he said it was called “paprika hendl,” and that, as it was a national dish, I should be able to get it anywhere along the Carpathians. I found my smattering of German very useful here; indeed, I don’t know how I should be able to get on without it. diff --git a/benches/micro.rs b/benches/micro.rs new file mode 100644 index 0000000..d55402e --- /dev/null +++ b/benches/micro.rs @@ -0,0 +1,112 @@ +#![allow(missing_docs)] + +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; + +use fsst::{CompressorBuilder, Symbol}; + +fn one_megabyte(seed: &[u8]) -> Vec { + seed.iter().copied().cycle().take(1024 * 1024).collect() +} + +fn bench_compress(c: &mut Criterion) { + let mut group = c.benchmark_group("compress-overhead"); + // Reusable memory to hold outputs + let mut output_buf: Vec = Vec::with_capacity(8 * 1024 * 1024); + + // We create a symbol table that requires probing the hash table to perform + // decompression. 
+ group.bench_function("compress-hashtab", |b| { + let mut compressor = CompressorBuilder::new(); + compressor.insert(Symbol::from_slice(b"abcdefgh"), 8); + let compressor = compressor.build(); + + let word = u64::from_le_bytes([b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']); + b.iter(|| unsafe { compressor.compress_word(word, output_buf.as_mut_ptr()) }); + }); + + // We create a symbol table that is able to short-circuit the decompression + group.bench_function("compress-twobytes", |b| { + let mut compressor = CompressorBuilder::new(); + compressor.insert(Symbol::from_slice(&[b'a', b'b', 0, 0, 0, 0, 0, 0]), 2); + let compressor = compressor.build(); + + let word = u64::from_le_bytes([b'a', b'b', 0, 0, 0, 0, 0, 0]); + b.iter(|| unsafe { compressor.compress_word(word, output_buf.as_mut_ptr()) }); + }); + group.finish(); + + let mut group = c.benchmark_group("cf=1"); + let test_string = one_megabyte(b"aaaaaaaa"); + group.throughput(Throughput::Bytes(test_string.len() as u64)); + group.bench_function("compress", |b| { + let mut compressor = CompressorBuilder::new(); + assert!(compressor.insert(Symbol::from_u8(b'a'), 1)); + let compressor = compressor.build(); + + b.iter(|| unsafe { + compressor.compress_into(&test_string, &mut output_buf); + }) + }); + group.finish(); + + let mut group = c.benchmark_group("cf=2"); + let test_string = one_megabyte(b"ab"); + + group.throughput(Throughput::Bytes(test_string.len() as u64)); + group.bench_function("compress", |b| { + let mut compressor = CompressorBuilder::new(); + // This outputs two codes for every 4 bytes of text. 
+ assert!(compressor.insert(Symbol::from_slice(&[b'a', 0, 0, 0, 0, 0, 0, 0]), 1)); + assert!(compressor.insert(Symbol::from_slice(&[b'b', b'a', b'b', 0, 0, 0, 0, 0]), 3)); + let compressor = compressor.build(); + + b.iter(|| unsafe { + compressor.compress_into(&test_string, &mut output_buf); + }) + }); + group.finish(); + + let mut group = c.benchmark_group("cf=4"); + let test_string = one_megabyte(b"abcd"); + group.throughput(Throughput::Bytes(test_string.len() as u64)); + group.bench_function("compress", |b| { + let mut compressor = CompressorBuilder::new(); + assert!(compressor.insert(Symbol::from_slice(&[b'a', b'b', b'c', b'd', 0, 0, 0, 0]), 4)); + let compressor = compressor.build(); + + b.iter(|| unsafe { + compressor.compress_into(&test_string, &mut output_buf); + }) + }); + group.finish(); + + let mut group = c.benchmark_group("cf=8"); + let test_string = one_megabyte(b"abcdefgh"); + group.throughput(Throughput::Bytes(test_string.len() as u64)); + group.bench_function("compress", |b| { + let mut compressor = CompressorBuilder::new(); + assert!(compressor.insert(Symbol::from_slice(b"abcdefgh"), 8)); + let compressor = compressor.build(); + + b.iter(|| unsafe { + compressor.compress_into(&test_string, &mut output_buf); + }) + }); + + group.bench_function("decompress", |b| { + let mut compressor = CompressorBuilder::new(); + assert!(compressor.insert(Symbol::from_slice(b"abcdefgh"), 8)); + let compressor = compressor.build(); + let compressed = compressor.compress(&test_string); + + let decompressor = compressor.decompressor(); + + b.iter(|| decompressor.decompress(&compressed)) + }); + group.finish(); + + let _ = std::hint::black_box(output_buf); +} + +criterion_group!(bench_micro, bench_compress); +criterion_main!(bench_micro); diff --git a/examples/file_compressor.rs b/examples/file_compressor.rs index 3314c92..7ae27bf 100644 --- a/examples/file_compressor.rs +++ b/examples/file_compressor.rs @@ -1,21 +1,17 @@ #![allow(missing_docs, clippy::use_debug)] -//! 
This is a command line program that expects two input files as arguments. -//! -//! The first is the file to train a symbol table on. -//! -//! The second is the file to compress. The compressor will run and compress -//! in chunks of 16MB, logging the compression ratio for each chunk. +//! This is a command line program that expects an input file as an argument, +//! and trains a symbol table that it then uses to compress the file in-memory. //! //! Example: //! //! ``` -//! cargo run --release --example file_compressor -- file1.csv file2.csv +//! cargo run --release --example file_compressor -- lineitem.tbl //! ``` use std::{ fs::File, io::Read, - os::unix::fs::{FileExt, MetadataExt}, + // io::{Read, Write}, path::Path, }; @@ -23,50 +19,37 @@ use fsst::Compressor; fn main() { let args: Vec<_> = std::env::args().skip(1).collect(); - assert!(args.len() >= 2, "args TRAINING and FILE must be provided"); - let train_path = Path::new(&args[0]); - let input_path = Path::new(&args[1]); + let input_path = Path::new(&args[0]); - let mut train_bytes = Vec::new(); + let mut string = String::new(); { - let mut f = File::open(train_path).unwrap(); - f.read_to_end(&mut train_bytes).unwrap(); - } - - println!("building the compressor from {train_path:?}..."); - let compressor = Compressor::train(&train_bytes); - - println!("compressing blocks of {input_path:?} with compressor..."); - - let f = File::open(input_path).unwrap(); - let size_bytes = f.metadata().unwrap().size() as usize; - - const CHUNK_SIZE: usize = 16 * 1024 * 1024; - - let mut chunk_idx = 1; - let mut pos = 0; - let mut chunk = vec![0u8; CHUNK_SIZE]; - while pos + CHUNK_SIZE < size_bytes { - f.read_exact_at(&mut chunk, pos as u64).unwrap(); - // Compress the chunk, don't write it anywhere. 
- let compact = compressor.compress(&chunk); - let compression_ratio = (CHUNK_SIZE as f64) / (compact.len() as f64); - println!("compressed chunk {chunk_idx} with ratio {compression_ratio}"); - - pos += CHUNK_SIZE; - chunk_idx += 1; + let mut f = File::open(input_path).unwrap(); + f.read_to_string(&mut string).unwrap(); } - - // Read last chunk with a new custom-sized buffer. - if pos < size_bytes { - let amount = size_bytes - pos; - chunk = vec![0u8; size_bytes - pos]; - f.read_exact_at(&mut chunk, pos as u64).unwrap(); - // Compress the chunk, don't write it anywhere. - let compact = compressor.compress(&chunk[0..amount]); - let compression_ratio = (amount as f64) / (compact.len() as f64); - println!("compressed chunk {chunk_idx} with ratio {compression_ratio}"); + let uncompressed_size = string.as_bytes().len(); + let lines: Vec<&[u8]> = string.lines().map(|line| line.as_bytes()).collect(); + + // let mut output = File::create(output_path).unwrap(); + let start = std::time::Instant::now(); + let compressor = Compressor::train(&lines); + let duration = std::time::Instant::now().duration_since(start); + println!("train took {}µs", duration.as_micros()); + let mut compressed_size = 0; + + let mut buffer = Vec::with_capacity(8 * 1024 * 1024); + + let start = std::time::Instant::now(); + for text in lines { + unsafe { compressor.compress_into(text, &mut buffer) }; + compressed_size += buffer.len(); } - println!("done"); + let duration = std::time::Instant::now().duration_since(start); + println!("compression took {}µs", duration.as_micros()); + println!( + "compressed {} -> {} ({}%)", + uncompressed_size, + compressed_size, + 100.0 * (compressed_size as f64) / (uncompressed_size as f64) + ); } diff --git a/examples/round_trip.rs b/examples/round_trip.rs index 038b932..7044f72 100644 --- a/examples/round_trip.rs +++ b/examples/round_trip.rs @@ -7,7 +7,7 @@ use fsst::Compressor; fn main() { // Train on a sample. 
let sample = "the quick brown fox jumped over the lazy dog"; - let trained = Compressor::train(sample.as_bytes()); + let trained = Compressor::train(&vec![sample.as_bytes()]); let compressed = trained.compress(sample.as_bytes()); println!("compressed: {} => {}", sample.len(), compressed.len()); // decompress now diff --git a/fuzz/Cargo.lock b/fuzz/Cargo.lock index 8c0cc9f..4c0ea3c 100644 --- a/fuzz/Cargo.lock +++ b/fuzz/Cargo.lock @@ -21,7 +21,7 @@ dependencies = [ [[package]] name = "fsst-rs" -version = "0.1.0" +version = "0.2.3" [[package]] name = "fsst-rs-fuzz" diff --git a/fuzz/fuzz_targets/fuzz_compress.rs b/fuzz/fuzz_targets/fuzz_compress.rs index a871293..50e9d31 100644 --- a/fuzz/fuzz_targets/fuzz_compress.rs +++ b/fuzz/fuzz_targets/fuzz_compress.rs @@ -4,7 +4,7 @@ use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { let compressor = - fsst::Compressor::train("the quick brown fox jumped over the lazy dog".as_bytes()); + fsst::Compressor::train(&vec![b"the quick brown fox jumped over the lazy dog"]); let compress = compressor.compress(data); let decompress = compressor.decompressor().decompress(&compress); assert_eq!(&decompress, data); diff --git a/fuzz/fuzz_targets/fuzz_train.rs b/fuzz/fuzz_targets/fuzz_train.rs index 5d3dada..18581e1 100644 --- a/fuzz/fuzz_targets/fuzz_train.rs +++ b/fuzz/fuzz_targets/fuzz_train.rs @@ -3,5 +3,5 @@ use libfuzzer_sys::fuzz_target; fuzz_target!(|data: &[u8]| { - let _ = fsst::Compressor::train(data); + let _ = fsst::Compressor::train(&vec![data]); }); diff --git a/src/builder.rs b/src/builder.rs index fefc55d..ed24874 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -7,7 +7,10 @@ use std::cmp::Ordering; use std::collections::BinaryHeap; -use crate::{CodeMeta, Compressor, Symbol, ESCAPE_CODE, MAX_CODE}; +use crate::{ + advance_8byte_word, compare_masked, lossy_pht::LossyPHT, Code, Compressor, Symbol, + FSST_CODE_BASE, FSST_CODE_MASK, +}; /// Bitmap that only works for values up to 512 #[derive(Clone, Copy, Debug, 
Default)] @@ -18,9 +21,12 @@ struct CodesBitmap { assert_sizeof!(CodesBitmap => 64); impl CodesBitmap { - /// Set the indicated bit. Must be between 0 and [`MAX_CODE`][crate::MAX_CODE]. + /// Set the indicated bit. Must be between 0 and [`FSST_CODE_MASK`][crate::FSST_CODE_MASK]. pub(crate) fn set(&mut self, index: usize) { - debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}"); + debug_assert!( + index <= FSST_CODE_MASK as usize, + "code cannot exceed {FSST_CODE_MASK}" + ); let map = index >> 6; self.codes[map] |= 1 << (index % 64); @@ -28,7 +34,10 @@ impl CodesBitmap { /// Check if `index` is present in the bitmap pub(crate) fn is_set(&self, index: usize) -> bool { - debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}"); + debug_assert!( + index <= FSST_CODE_MASK as usize, + "code cannot exceed {FSST_CODE_MASK}" + ); let map = index >> 6; self.codes[map] & 1 << (index % 64) != 0 @@ -82,6 +91,10 @@ impl<'a> Iterator for CodesIterator<'a> { let position = self.block.trailing_zeros() as usize; let code = self.reference + position; + if code >= 511 { + return None; + } + // The next iteration will calculate with reference to the returned code + 1 self.reference = code + 1; self.block = if position == 63 { @@ -112,9 +125,13 @@ struct Counter { pair_index: Vec, } -const COUNTS1_SIZE: usize = MAX_CODE as usize; +const COUNTS1_SIZE: usize = (FSST_CODE_MASK + 1) as usize; + // NOTE: in Rust, creating a 1D vector of length N^2 is ~4x faster than creating a 2-D vector, // because `vec!` has a specialization for zero. +// +// We also include +1 extra row at the end so that we can do writes into the counters without a branch +// for the first iteration. const COUNTS2_SIZE: usize = COUNTS1_SIZE * COUNTS1_SIZE; impl Counter { @@ -138,20 +155,23 @@ impl Counter { #[inline] fn record_count1(&mut self, code1: u16) { - if self.code1_index.is_set(code1 as usize) { - self.counts1[code1 as usize] += 1; + // If not set, we want to start at one. 
+ let base = if self.code1_index.is_set(code1 as usize) { + self.counts1[code1 as usize] } else { - self.counts1[code1 as usize] = 1; - } + 0 + }; + + self.counts1[code1 as usize] = base + 1; self.code1_index.set(code1 as usize); } #[inline] fn record_count2(&mut self, code1: u16, code2: u16) { - debug_assert!(self.code1_index.is_set(code1 as usize)); + debug_assert!(code1 == FSST_CODE_MASK || self.code1_index.is_set(code1 as usize)); debug_assert!(self.code1_index.is_set(code2 as usize)); - let idx = (code1 as usize) * 511 + (code2 as usize); + let idx = (code1 as usize) * COUNTS1_SIZE + (code2 as usize); if self.pair_index[code1 as usize].is_set(code2 as usize) { self.counts2[idx] += 1; } else { @@ -173,7 +193,7 @@ impl Counter { debug_assert!(self.code1_index.is_set(code2 as usize)); debug_assert!(self.pair_index[code1 as usize].is_set(code2 as usize)); - let idx = (code1 as usize) * 511 + (code2 as usize); + let idx = (code1 as usize) * 512 + (code2 as usize); self.counts2[idx] } @@ -202,15 +222,123 @@ impl Counter { } } -/// The number of generations used for training. This is taken from the [FSST paper]. -/// -/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf -#[cfg(not(miri))] -const MAX_GENERATIONS: usize = 5; -#[cfg(miri)] -const MAX_GENERATIONS: usize = 2; +/// Entrypoint for building a new `Compressor`. +pub struct CompressorBuilder { + /// Table mapping codes to symbols. + /// + /// The entries 0-255 are setup in some other way here + symbols: Vec, + + /// The number of entries in the symbol table that have been populated, not counting + /// the escape values. + n_symbols: u8, + + /// Counts for number of symbols of each length. + /// + /// `len_histogram[len-1]` = count of the symbols of length `len`. + len_histogram: [u8; 8], + + /// Inverted index mapping 1-byte symbols to codes. + /// + /// This is only used for building, not used by the final `Compressor`. 
+ codes_one_byte: Vec, + + /// Inverted index mapping 2-byte symbols to codes + codes_two_byte: Vec, + + /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more + lossy_pht: LossyPHT, +} + +impl CompressorBuilder { + /// Create a new builder. + pub fn new() -> Self { + // NOTE: `vec!` has a specialization for building a new vector of `0u64`. Because Symbol and u64 + // have the same bit pattern, we can allocate as u64 and transmute. If we do `vec![Symbol::EMPTY; N]`, + // that will create a new Vec and call `Symbol::EMPTY.clone()` `N` times which is considerably slower. + let symbols = vec![0u64; 511]; + + // SAFETY: transmute safety assured by the compiler. + let symbols: Vec = unsafe { std::mem::transmute(symbols) }; + + let mut table = Self { + symbols, + n_symbols: 0, + len_histogram: [0; 8], + codes_two_byte: Vec::with_capacity(65_536), + codes_one_byte: Vec::with_capacity(512), + lossy_pht: LossyPHT::new(), + }; + + // Populate the escape byte entries. + for byte in 0..=255 { + let symbol = Symbol::from_u8(byte); + table.symbols[byte as usize] = symbol; + } + + // Fill codes_one_byte with pseudocodes for each byte. + for byte in 0..=255 { + // Push pseudocode for single-byte escape. + table.codes_one_byte.push(Code::new_escape(byte)); + } + + // Fill codes_two_byte with pseudocode of first byte + for byte1 in 0..=255 { + for _byte2 in 0..=255 { + table.codes_two_byte.push(Code::new_escape(byte1)); + } + } + + table + } +} + +impl Default for CompressorBuilder { + fn default() -> Self { + Self::new() + } +} + +impl CompressorBuilder { + /// Attempt to insert a new symbol at the end of the table. + /// + /// # Panics + /// + /// Panics if the table is already full. + /// + /// # Returns + /// + /// Returns true if the symbol was inserted successfully, or false if it conflicted + /// with an existing symbol. 
+ pub fn insert(&mut self, symbol: Symbol, len: usize) -> bool { + assert!(self.n_symbols < 255, "cannot insert into full symbol table"); + debug_assert!(len == symbol.len(), "provided len != symbol.len()"); + + if len == 2 { + // shortCodes + self.codes_two_byte[symbol.first2() as usize] = + Code::new_symbol_building(self.n_symbols, 2); + } else if len == 1 { + // byteCodes + self.codes_one_byte[symbol.first_byte() as usize] = + Code::new_symbol_building(self.n_symbols, 1); + } else { + // Symbols of 3 or more bytes go into the hash table + if !self.lossy_pht.insert(symbol, len, self.n_symbols) { + return false; + } + } + + // Increment length histogram. + self.len_histogram[len - 1] += 1; + + // Insert successfully stored symbol at end of the symbol table + // Note the rescaling from range [0-254] -> [256, 510]. + self.symbols[256 + (self.n_symbols as usize)] = symbol; + self.n_symbols += 1; + true + } -impl Compressor { /// Clear all set items from the compressor. /// /// This is considerably faster than building a new Compressor from scratch for each @@ -219,18 +347,253 @@ impl Compressor { // Eliminate every observed code from the table. for code in 0..(256 + self.n_symbols as usize) { let symbol = self.symbols[code]; - if symbol.len() <= 2 { - // Clear the codes_twobyte array - self.codes_twobyte[symbol.first_two_bytes() as usize] = CodeMeta::EMPTY; + if symbol.len() == 1 { + // Reset the entry from the codes_one_byte array. + self.codes_one_byte[symbol.first_byte() as usize] = + Code::new_escape(symbol.first_byte()); + } else if symbol.len() == 2 { + // Reset the entry from the codes_two_byte array. 
+ self.codes_two_byte[symbol.first2() as usize] = + Code::new_escape(symbol.first_byte()); } else { - // Clear the hashtable + // Clear the hashtable entry self.lossy_pht.remove(symbol); } } + // Reset len histogram + for i in 0..=7 { + self.len_histogram[i] = 0; + } + self.n_symbols = 0; } + /// Finalizing the table is done once building is complete to prepare for efficient + /// compression. + /// + /// When we finalize the table, the following modifications are made in-place: + /// + /// 1. The codes are renumbered so that all symbols are ordered by length (order 23456781). + /// During this process, the two byte symbols are separated into a byte_lim and a suffix_lim, + /// so we know that we don't need to check the suffix limitations instead. + /// 2. The 1-byte symbols index is merged into the 2-byte symbols index to allow for use of only + /// a single index in front of the hash table. + /// + /// # Returns + /// + /// Returns the `suffix_lim`, which is the index of the two-byte code before where we know + /// there are no longer suffixies in the symbol table. + /// + /// Also returns the lengths vector, which is of length `n_symbols` and contains the + /// length for each of the values. + #[inline(never)] + fn finalize(&mut self) -> (u8, Vec) { + // Create a cumulative sum of each of the elements of the input line numbers. + // Do a map that includes the previously seen value as well. + // Regroup symbols based on their lengths. + // Space at the end of the symbol table reserved for the one-byte codes. + let byte_lim = self.n_symbols - self.len_histogram[0]; + + // Start code for each length. + // Length 1: at the end of symbol table. + // Length 2: starts at 0. Split into before/after suffixLim. + let mut codes_by_length = [0u8; 8]; + codes_by_length[0] = byte_lim; + codes_by_length[1] = 0; + + // codes for lengths 3..=8 start where the previous ones end. 
+ for i in 1..7 { + codes_by_length[i + 1] = codes_by_length[i] + self.len_histogram[i]; + } + + // no_suffix_code is the lowest code for a symbol that does not have a longer 3+ byte + // suffix in the table. + // This value starts at 0 and extends up. + let mut no_suffix_code = 0; + + // The codes that do not have a suffix begin just before the range of the 3-byte codes. + let mut has_suffix_code = codes_by_length[2]; + + // Assign each symbol a new code ordered by lengths, in the order + // 2(no suffix) | 2 (suffix) | 3 | 4 | 5 | 6 | 7 | 8 | 1 + let mut new_codes = [0u8; FSST_CODE_BASE as usize]; + + let mut symbol_lens = [0u8; FSST_CODE_BASE as usize]; + + for i in 0..(self.n_symbols as usize) { + let symbol = self.symbols[256 + i]; + let len = symbol.len(); + if len == 2 { + let has_suffix = self + .symbols + .iter() + .skip(FSST_CODE_BASE as usize) + .enumerate() + .any(|(k, other)| i != k && symbol.first2() == other.first2()); + + if has_suffix { + // Symbols that have a longer suffix are inserted at the end of the 2-byte range + has_suffix_code -= 1; + new_codes[i] = has_suffix_code; + } else { + // Symbols that do not have a longer suffix are inserted at the start of + // the 2-byte range. + new_codes[i] = no_suffix_code; + no_suffix_code += 1; + } + } else { + // Assign new code based on the next code available for the given length symbol + new_codes[i] = codes_by_length[len - 1]; + codes_by_length[len - 1] += 1; + } + + // Write the symbol into the front half of the symbol table. + // We are reusing the space that was previously occupied by escapes. + self.symbols[new_codes[i] as usize] = symbol; + symbol_lens[new_codes[i] as usize] = len as u8; + } + + // Truncate the symbol table to only include the "true" symbols. + self.symbols.truncate(self.n_symbols as usize); + + // Rewrite the codes_one_byte table to point at the new code values. + // Replace pseudocodes with escapes. 
+ for byte in 0..=255 { + let one_byte = self.codes_one_byte[byte]; + if one_byte.extended_code() >= FSST_CODE_BASE { + let new_code = new_codes[one_byte.code() as usize]; + self.codes_one_byte[byte] = Code::new_symbol(new_code, 1); + } else { + // After finalize: codes_one_byte contains the unused value + self.codes_one_byte[byte] = Code::UNUSED; + } + } + + // Rewrite the codes_two_byte table to point at the new code values. + // Replace pseudocodes with escapes. + for two_bytes in 0..=65_535 { + let two_byte = self.codes_two_byte[two_bytes]; + if two_byte.extended_code() >= FSST_CODE_BASE { + let new_code = new_codes[two_byte.code() as usize]; + self.codes_two_byte[two_bytes] = Code::new_symbol(new_code, 2); + } else { + // The one-byte code for the given code number here... + let new_code = self.codes_one_byte[two_bytes as u8 as usize]; + self.codes_two_byte[two_bytes] = new_code; + } + } + + // Reset values in the hash table as well. + self.lossy_pht.renumber(&new_codes); + + // Pre-compute the lengths + let mut lengths = Vec::with_capacity(self.n_symbols as usize); + for symbol in &self.symbols { + lengths.push(symbol.len() as u8); + } + + (has_suffix_code, lengths) + } + + /// Build into the final hash table. + pub fn build(mut self) -> Compressor { + // finalize the symbol table by inserting the codes_twobyte values into + // the relevant parts of the `codes_onebyte` set. + + let (has_suffix_code, lengths) = self.finalize(); + + Compressor { + symbols: self.symbols, + lengths, + n_symbols: self.n_symbols, + has_suffix_code, + codes_two_byte: self.codes_two_byte, + lossy_pht: self.lossy_pht, + } + } +} + +/// The number of generations used for training. This is taken from the [FSST paper]. 
+/// +/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf +#[cfg(not(miri))] +const GENERATIONS: [usize; 5] = [8usize, 38, 68, 98, 128]; +#[cfg(miri)] +const GENERATIONS: [usize; 3] = [8usize, 38, 128]; + +const FSST_SAMPLETARGET: usize = 1 << 14; +const FSST_SAMPLEMAX: usize = 1 << 15; +const FSST_SAMPLELINE: usize = 512; + +/// Create a sample from a set of strings in the input. +/// +/// Sample is constructed by copying "chunks" from the `str_in`s into the `sample_buf`, the +/// returned slices are pointers into the `sample_buf`. +/// +/// SAFETY: sample_buf must be >= FSST_SAMPLEMAX bytes long. Providing something less may cause unexpected failures. +#[allow(clippy::ptr_arg)] +fn make_sample<'a, 'b: 'a>(sample_buf: &'a mut Vec<u8>, str_in: &Vec<&'b [u8]>) -> Vec<&'a [u8]> { + debug_assert!( + sample_buf.capacity() >= FSST_SAMPLEMAX, + "sample_buf.len() < FSST_SAMPLEMAX" + ); + + let mut sample: Vec<&[u8]> = Vec::new(); + + let tot_size: usize = str_in.iter().map(|s| s.len()).sum(); + if tot_size < FSST_SAMPLETARGET { + return str_in.clone(); + } + + let mut sample_rnd = fsst_hash(4637947); + let sample_lim = FSST_SAMPLETARGET; + let mut sample_buf_offset: usize = 0; + + while sample_buf_offset < sample_lim { + sample_rnd = fsst_hash(sample_rnd); + let mut line_nr = (sample_rnd as usize) % str_in.len(); + + // Find the first non-empty chunk starting at line_nr, wrapping around if + // necessary. 
+ // + // TODO: this loops forever whenever str_in[line_nr] is empty: line_nr is never advanced, and the wrap check below can never fire since the modulo above guarantees line_nr < str_in.len() + while str_in[line_nr].is_empty() { + if line_nr == str_in.len() { + line_nr = 0; + } + } + + let line = str_in[line_nr]; + let chunks = 1 + ((line.len() - 1) / FSST_SAMPLELINE); + sample_rnd = fsst_hash(sample_rnd); + let chunk = FSST_SAMPLELINE * ((sample_rnd as usize) % chunks); + + let len = FSST_SAMPLELINE.min(line.len() - chunk); + + sample_buf.extend_from_slice(&str_in[line_nr][chunk..chunk + len]); + + // SAFETY: this is the data we just placed into `sample_buf` in the line above. + let slice = + unsafe { std::slice::from_raw_parts(sample_buf.as_ptr().add(sample_buf_offset), len) }; + + sample.push(slice); + + sample_buf_offset += len; + } + + sample +} + +/// Hash function used in various components of the library. +/// +/// This is equivalent to the FSST_HASH macro from the C++ implementation. +#[inline] +pub(crate) fn fsst_hash(value: u64) -> u64 { + (value * 2971215073) ^ (value >> 15) +} + +impl Compressor { /// Build and train a `Compressor` from a sample corpus of text. /// /// This function implements the generational algorithm described in the [FSST paper] Section @@ -240,76 +603,167 @@ impl Compressor { /// code). /// /// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf - pub fn train(corpus: impl AsRef<[u8]>) -> Self { - let mut compressor = Compressor::default(); - // TODO(aduffy): handle truncating/sampling if corpus > requires sample size. 
- let sample = corpus.as_ref(); - if sample.is_empty() { - return compressor; - } + pub fn train(values: &Vec<&[u8]>) -> Self { + let mut builder = CompressorBuilder::new(); - let mut counter = Counter::new(); - for _generation in 0..(MAX_GENERATIONS - 1) { - compressor.compress_count(sample, &mut counter); - compressor.optimize(&counter, true); - counter.clear(); + if values.is_empty() { + return builder.build(); } - compressor.compress_count(sample, &mut counter); - compressor.optimize(&counter, true); + let mut counters = Counter::new(); + let mut sample_memory = Vec::with_capacity(FSST_SAMPLEMAX); + let sample = make_sample(&mut sample_memory, values); + for sample_frac in GENERATIONS { + for (i, line) in sample.iter().enumerate() { + if sample_frac < 128 && ((fsst_hash(i as u64) & 127) as usize) > sample_frac { + continue; + } + + builder.compress_count(line, &mut counters); + } + + builder.optimize(&counters, sample_frac); + counters.clear(); + } - compressor + builder.build() } } -impl Compressor { - /// Compress the text using the current symbol table. Count the code occurrences - /// and code-pair occurrences to allow us to calculate apparent gain. - fn compress_count(&self, sample: &[u8], counter: &mut Counter) { - let compressed = self.compress(sample); - let len = compressed.len(); +impl CompressorBuilder { + /// Find the longest symbol using the hash table and the codes_one_byte and codes_two_byte indexes. 
+ fn find_longest_symbol(&self, word: u64) -> Code { + // Probe the hash table first to see if we have a long match + let entry = self.lossy_pht.lookup(word); + let ignored_bits = entry.ignored_bits; - if len == 0 { - return; + // If the entry is valid, return the code + if !entry.is_unused() && compare_masked(word, entry.symbol.as_u64(), ignored_bits) { + return entry.code; } - fn next_code(pos: usize, compressed: &[u8]) -> (u16, usize) { - if compressed[pos] == ESCAPE_CODE { - (compressed[pos + 1] as u16, 2) - } else { - (256 + compressed[pos] as u16, 1) + // Try and match first two bytes + let twobyte = self.codes_two_byte[word as u16 as usize]; + if twobyte.extended_code() >= FSST_CODE_BASE { + return twobyte; + } + + // Fall back to single-byte match + self.codes_one_byte[word as u8 as usize] + } + + /// Compress the text using the current symbol table. Count the code occurrences + /// and code-pair occurrences, calculating total gain using the current compressor. + /// + /// NOTE: this is largely an unfortunate amount of copy-paste from `compress`, just to make sure + /// we can do all the counting in a single pass. + fn compress_count(&self, sample: &[u8], counter: &mut Counter) -> usize { + let mut gain = 0; + if sample.is_empty() { + return gain; + } + + let mut in_ptr = sample.as_ptr(); + + // SAFETY: `end` will point just after the end of the `plaintext` slice. + let in_end = unsafe { in_ptr.byte_add(sample.len()) }; + let in_end_sub8 = in_end as usize - 8; + + let mut prev_code: u16 = FSST_CODE_MASK; + + while (in_ptr as usize) < (in_end_sub8) { + // SAFETY: ensured in-bounds by loop condition. + let word: u64 = unsafe { std::ptr::read_unaligned(in_ptr as *const u64) }; + let code = self.find_longest_symbol(word); + let code_u16 = code.extended_code(); + + // Gain increases by the symbol length if a symbol matches, or 0 + // if an escape is emitted. 
+ gain += (code.len() as usize) - ((code_u16 < 256) as usize); + + // Record the single and pair counts + counter.record_count1(code_u16); + counter.record_count2(prev_code, code_u16); + + // Also record the count for just extending by a single byte, but only if + // the symbol is not itself a single byte. + if code.len() > 1 { + let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16; + counter.record_count1(code_first_byte); + counter.record_count2(prev_code, code_first_byte); } + + // SAFETY: pointer bound is checked in loop condition before any access is made. + in_ptr = unsafe { in_ptr.byte_add(code.len() as usize) }; + + prev_code = code_u16; + } + + let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) }; + debug_assert!( + remaining_bytes.is_positive(), + "in_ptr exceeded in_end, should not be possible" + ); + let remaining_bytes = remaining_bytes as usize; + + // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above, + // but shift data out of this word rather than advancing an input pointer and potentially reading + // unowned memory + let mut bytes = [0u8; 8]; + unsafe { + // SAFETY: it is safe to read up to remaining_bytes from in_ptr, and remaining_bytes + // will be <= 8 bytes. + std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes); } + let mut last_word = u64::from_le_bytes(bytes); + + let mut remaining_bytes = remaining_bytes; - // Get first code, record count - let (code, pos) = next_code(0, &compressed); - counter.record_count1(code); + while remaining_bytes > 0 { + // SAFETY: ensured in-bounds by loop condition. + let code = self.find_longest_symbol(last_word); + let code_u16 = code.extended_code(); - let mut pos = pos; - let mut prev_code = code; + // Gain increases by the symbol length if a symbol matches, or 0 + // if an escape is emitted. 
+ gain += (code.len() as usize) - ((code_u16 < 256) as usize); - while pos < len { - let (code, advance) = next_code(pos, &compressed); - pos += advance; + // Record the single and pair counts + counter.record_count1(code_u16); + counter.record_count2(prev_code, code_u16); - counter.record_count1(code); - counter.record_count2(prev_code, code); + // Also record the count for just extending by a single byte, but only if + // the symbol is not itself a single byte. + if code.len() > 1 { + let code_first_byte = self.symbols[code_u16 as usize].first_byte() as u16; + counter.record_count1(code_first_byte); + counter.record_count2(prev_code, code_first_byte); + } + + // Advance our last_word "input pointer" by shifting off the covered values. + let advance = code.len() as usize; + remaining_bytes -= advance; + last_word = advance_8byte_word(last_word, advance); - prev_code = code; + prev_code = code_u16; } + + gain } /// Using a set of counters and the existing set of symbols, build a new /// set of symbols/codes that optimizes the gain over the distribution in `counter`. - fn optimize(&mut self, counters: &Counter, include_ascii: bool) { + fn optimize(&mut self, counters: &Counter, sample_frac: usize) { let mut pqueue = BinaryHeap::with_capacity(65_536); for code1 in counters.first_codes() { let symbol1 = self.symbols[code1 as usize]; let symbol1_len = symbol1.len(); let count = counters.count1(code1); - // If count is zero, we can skip the whole inner loop. - if count == 0 { + + // From the c++ impl: + // "improves both compression speed (less candidates), but also quality!!" + if count < (5 * sample_frac / 128) { continue; } @@ -319,15 +773,19 @@ impl Compressor { if code1 < 256 { gain *= 8; } - if gain > 0 { - pqueue.push(Candidate { - symbol: symbol1, - gain, - }); + + pqueue.push(Candidate { + symbol: symbol1, + gain, + }); + + // Skip merges on last round, or when symbol cannot be extended. 
+ if sample_frac >= 128 || symbol1_len == 8 { + continue; } for code2 in counters.second_codes(code1) { - let symbol2 = &self.symbols[code2 as usize]; + let symbol2 = self.symbols[code2 as usize]; // If merging would yield a symbol of length greater than 8, skip. if symbol1_len + symbol2.len() > 8 { @@ -335,12 +793,11 @@ impl Compressor { } let new_symbol = symbol1.concat(symbol2); let gain = counters.count2(code1, code2) * new_symbol.len(); - if gain > 0 { - pqueue.push(Candidate { - symbol: new_symbol, - gain, - }) - } + + pqueue.push(Candidate { + symbol: new_symbol, + gain, + }) } } @@ -351,29 +808,10 @@ impl Compressor { let mut n_symbols = 0; while !pqueue.is_empty() && n_symbols < 255 { let candidate = pqueue.pop().unwrap(); - if self.insert(candidate.symbol) { + if self.insert(candidate.symbol, candidate.symbol.len()) { n_symbols += 1; } } - - // If there are leftover slots, fill them with ASCII chars. - // This helps reduce the number of escapes. - // - // Note that because of the lossy hash table, we won't accidentally - // save the same ASCII character twice into the table. - if include_ascii { - for character in - " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ[](){}:?/<>".bytes() - { - if n_symbols == 255 { - break; - } - - if self.insert(Symbol::from_u8(character)) { - n_symbols += 1 - } - } - } } } @@ -422,17 +860,21 @@ mod test { #[test] fn test_builder() { // Train a Compressor on the toy string - let text = "hello world"; - let table = Compressor::train(text.as_bytes()); + let text = b"hello hello hello hello hello"; + + // count of 5 is the cutoff for including a symbol in the table. 
+ let table = Compressor::train(&vec![text, text, text, text, text]); // Use the table to compress a string, see the values - let compressed = table.compress(text.as_bytes()); + let compressed = table.compress(text); // Ensure that the compressed string has no escape bytes assert!(compressed.iter().all(|b| *b != ESCAPE_CODE)); // Ensure that we can compress a string with no values seen at training time, with escape bytes let compressed = table.compress("xyz123".as_bytes()); + let decompressed = table.decompressor().decompress(&compressed); + assert_eq!(&decompressed, b"xyz123"); assert_eq!( compressed, vec![ @@ -481,7 +923,7 @@ mod test { } assert_eq!( map.codes().collect::>(), - (0u16..512u16).collect::>() + (0u16..511u16).collect::>() ); } diff --git a/src/lib.rs b/src/lib.rs index cf33a4b..4f00b47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,20 +8,18 @@ macro_rules! assert_sizeof { }; } -use std::fmt::{Debug, Formatter}; - use lossy_pht::LossyPHT; +use std::fmt::{Debug, Formatter}; mod builder; mod lossy_pht; +pub use builder::*; + /// `Symbol`s are small (up to 8-byte) segments of strings, stored in a [`Compressor`][`crate::Compressor`] and /// identified by an 8-bit code. #[derive(Copy, Clone)] -pub union Symbol { - bytes: [u8; 8], - num: u64, -} +pub struct Symbol(u64); assert_sizeof!(Symbol => 8); @@ -31,17 +29,26 @@ impl Symbol { /// Constructor for a `Symbol` from an 8-element byte slice. 
pub fn from_slice(slice: &[u8; 8]) -> Self { - Self { bytes: *slice } + let num: u64 = slice[0] as u64 + | (slice[1] as u64) << 8 + | (slice[2] as u64) << 16 + | (slice[3] as u64) << 24 + | (slice[4] as u64) << 32 + | (slice[5] as u64) << 40 + | (slice[6] as u64) << 48 + | (slice[7] as u64) << 56; + + Self(num) } /// Return a zero symbol const fn zero() -> Self { - Self { num: 0 } + Self(0) } /// Create a new single-byte symbol pub fn from_u8(value: u8) -> Self { - Self { num: value as u64 } + Self(value as u64) } } @@ -53,8 +60,8 @@ impl Symbol { /// that holds the byte 0x00. In that case, the symbol contains `0x0000000000000000` /// but we want to interpret that as a one-byte symbol containing `0x00`. #[allow(clippy::len_without_is_empty)] - pub fn len(&self) -> usize { - let numeric = unsafe { self.num }; + pub fn len(self) -> usize { + let numeric = self.0; // For little-endian platforms, this counts the number of *trailing* zeros let null_bytes = (numeric.leading_zeros() >> 3) as usize; @@ -69,88 +76,72 @@ impl Symbol { } #[inline] - fn as_u64(&self) -> u64 { - // SAFETY: the bytes can always be viewed as a u64 - unsafe { self.num } + fn as_u64(self) -> u64 { + self.0 } /// Get the first byte of the symbol as a `u8`. /// /// If the symbol is empty, this will return the zero byte. #[inline] - pub fn first_byte(&self) -> u8 { - // SAFETY: the bytes can always be viewed as a u64 - unsafe { self.num as u8 } + pub fn first_byte(self) -> u8 { + self.0 as u8 } /// Get the first two bytes of the symbol as a `u16`. /// /// If the Symbol is one or zero bytes, this will return `0u16`. #[inline] - pub fn first_two_bytes(&self) -> u16 { - // SAFETY: the bytes can always be viewed as a u64 - unsafe { self.num as u16 } - } - - /// Access the Symbol as a slice. - pub fn as_slice(&self) -> &[u8] { - let len = self.len(); - // SAFETY: constructors will not allow building a struct where len > 8. 
- unsafe { &self.bytes[0..len] } + pub fn first2(self) -> u16 { + self.0 as u16 } - /// Returns true if the symbol is a prefix of the provided text. - pub fn is_prefix(&self, text: &[u8]) -> bool { - text.starts_with(self.as_slice()) + /// Get the first two bytes of the symbol as a `u16`. + /// + /// If the Symbol is one or zero bytes, this will return `0u16`. + #[inline] + pub fn first3(self) -> u64 { + self.0 & 0xFF_FF_FF } /// Return a new `Symbol` by logically concatenating ourselves with another `Symbol`. - pub fn concat(&self, other: &Self) -> Self { - let self_len = self.len(); - let new_len = self_len + other.len(); - assert!(new_len <= 8, "cannot build symbol with length > 8"); - - // SAFETY: we assert the combined length <= 8 - unsafe { - Self { - num: (other.num << (8 * self_len)) | self.num, - } - } - } -} + pub fn concat(self, other: Self) -> Self { + debug_assert!( + self.len() + other.len() <= 8, + "cannot build symbol with length > 8" + ); -#[cfg(test)] -mod test { - use crate::Symbol; + let self_len = self.len(); - #[test] - fn test_concat() { - let symbola = Symbol::from_u8(b'a'); - let symbolb = Symbol::from_u8(b'b'); - let symbolab = symbola.concat(&symbolb); - assert_eq!(symbolab.as_slice(), b"ab"); + Self((other.0 << (8 * self_len)) | self.0) } } impl Debug for Symbol { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let debug = self - .as_slice() - .iter() - .map(|c| *c as char) - .map(|c| { - if c.is_ascii() { - format!("{c}") - } else { - format!("{c:X?}") - } - }) - .collect::>(); - write!(f, "{:?}", debug) + write!(f, "[")?; + + let slice = &self.0.to_le_bytes()[0..self.len()]; + for c in slice.iter().map(|c| *c as char) { + if ('!'..='~').contains(&c) { + write!(f, "{c}")?; + } else if c == '\n' { + write!(f, " \\n ")?; + } else if c == '\t' { + write!(f, " \\t ")?; + } else if c == ' ' { + write!(f, " SPACE ")?; + } else { + write!(f, " 0x{:X?} ", c as u8)? 
+ } + } + + write!(f, "]") } } -/// Code and associated metadata fro a symbol. +/// A packed type containing a code value, as well as metadata about the symbol referred to by +/// the code. /// /// Logically, codes can range from 0-255 inclusive. This type holds both the 8-bit code as well as /// other metadata bit-packed into a `u16`. @@ -164,7 +155,7 @@ impl Debug for Symbol { /// /// Bits 12-15 store the length of the symbol (values ranging from 0-8). #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct CodeMeta(u16); +struct Code(u16); /// Code used to indicate bytes that are not in the symbol table. /// @@ -174,46 +165,69 @@ struct CodeMeta(u16); /// being looked up in the symbol table. pub const ESCAPE_CODE: u8 = 255; +/// Number of bits in the `ExtendedCode` that are used to dictate a code value. +pub const FSST_CODE_BITS: usize = 9; + +/// First bit of the "length" portion of an extended code. +pub const FSST_LEN_BITS: usize = 12; + +/// A code that never appears in practice, indicating an unused slot. +pub const FSST_CODE_UNUSED: u16 = 1u16 << FSST_CODE_BITS; + +/// Maximum code value in the extended code range. +pub const FSST_CODE_MAX: u16 = 1 << FSST_CODE_BITS; + /// Maximum value for the extended code range. /// /// When truncated to u8 this is code 255, which is equivalent to [`ESCAPE_CODE`]. -pub const MAX_CODE: u16 = 511; +pub const FSST_CODE_MASK: u16 = FSST_CODE_MAX - 1; + +/// First code in the symbol table that corresponds to a non-escape symbol. +pub const FSST_CODE_BASE: u16 = 256; #[allow(clippy::len_without_is_empty)] -impl CodeMeta { - const EMPTY: Self = CodeMeta(MAX_CODE); +impl Code { + /// Code for an unused slot in a symbol table or index. + /// + /// This corresponds to the maximum code with a length of 1. 
+ pub const UNUSED: Self = Code(FSST_CODE_MASK + (1 << 12)); - fn new(code: u8, escape: bool, len: u16) -> Self { - let value = (len << 12) | ((escape as u16) << 8) | (code as u16); - Self(value) + /// Create a new code for a symbol of given length. + fn new_symbol(code: u8, len: usize) -> Self { + Self(code as u16 + ((len as u16) << FSST_LEN_BITS)) } - /// Create a new code from a [`Symbol`]. - fn new_symbol(code: u8, symbol: Symbol) -> Self { - assert_ne!(code, ESCAPE_CODE, "ESCAPE_CODE cannot be used for symbol"); + /// Code for a new symbol during the building phase. + /// + /// The code is remapped from 0..254 to 256...510. + fn new_symbol_building(code: u8, len: usize) -> Self { + Self(code as u16 + 256 + ((len as u16) << FSST_LEN_BITS)) + } - Self::new(code, false, symbol.len() as u16) + /// Create a new code corresponding for an escaped byte. + fn new_escape(byte: u8) -> Self { + Self((byte as u16) + (1 << FSST_LEN_BITS)) } #[inline] - fn code(&self) -> u8 { + fn code(self) -> u8 { self.0 as u8 } #[inline] - fn extended_code(&self) -> u16 { + fn extended_code(self) -> u16 { self.0 & 0b111_111_111 } #[inline] - fn len(&self) -> u16 { - self.0 >> 12 + fn len(self) -> u16 { + self.0 >> FSST_LEN_BITS } } -impl Debug for CodeMeta { +impl Debug for Code { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CodeMeta") + f.debug_struct("TrainingCode") .field("code", &(self.0 as u8)) .field("is_escape", &(self.0 < 256)) .field("len", &(self.0 >> 12)) @@ -224,13 +238,11 @@ impl Debug for CodeMeta { /// Decompressor uses a symbol table to take a stream of 8-bit codes into a string. #[derive(Clone)] pub struct Decompressor<'a> { - /// Table mapping codes to symbols. - /// - /// The first 256 slots are escapes. The following slots (up to 254) - /// are for symbols with actual codes. - /// - /// This physical layout is important so that we can do straight-line execution in the decompress method. + /// Slice mapping codes to symbols. 
pub(crate) symbols: &'a [Symbol], + + /// Slice containing the length of each symbol in the `symbols` slice. + pub(crate) lengths: &'a [u8], } impl<'a> Decompressor<'a> { @@ -238,14 +250,14 @@ impl<'a> Decompressor<'a> { /// /// # Panics /// - /// If the provided symbol table has length greater than [`MAX_CODE`]. - pub fn new(symbols: &'a [Symbol]) -> Self { + /// If the provided symbol table has length greater than 256 + pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self { assert!( - symbols.len() <= MAX_CODE as usize, - "symbol table cannot have size exceeding MAX_CODE" + symbols.len() <= 255, + "symbol table cannot have size exceeding 255" ); - Self { symbols } + Self { symbols, lengths } } /// Decompress a byte slice that was previously returned by a compressor using @@ -265,24 +277,25 @@ impl<'a> Decompressor<'a> { // SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer unsafe { let write_addr = ptr.byte_offset(out_pos as isize); - write_addr.write(compressed[in_pos]); + std::ptr::write(write_addr, compressed[in_pos]); } out_pos += 1; in_pos += 1; } else { - let symbol = self.symbols[256 + code as usize]; + let symbol = self.symbols[code as usize]; + let length = self.lengths[code as usize]; // SAFETY: out_pos is always 8 bytes or more from the end of decoded buffer unsafe { let write_addr = ptr.byte_offset(out_pos as isize) as *mut u64; // Perform 8 byte unaligned write. 
- write_addr.write_unaligned(symbol.num); + write_addr.write_unaligned(symbol.as_u64()); } in_pos += 1; - out_pos += symbol.len(); + out_pos += length as usize; } } - assert!( + debug_assert!( in_pos >= compressed.len(), "decompression should exhaust input before output" ); @@ -302,11 +315,12 @@ impl<'a> Decompressor<'a> { /// Example usage: /// /// ``` -/// use fsst::{Symbol, Compressor}; -/// let mut compressor = Compressor::default(); -/// -/// // Insert a new symbol -/// assert!(compressor.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]))); +/// use fsst::{Symbol, Compressor, CompressorBuilder}; +/// let compressor = { +/// let mut builder = CompressorBuilder::new(); +/// builder.insert(Symbol::from_slice(&[b'h', b'e', b'l', b'l', b'o', 0, 0, 0]), 5); +/// builder.build() +/// }; /// /// let compressed = compressor.compress("hello".as_bytes()); /// assert_eq!(compressed, vec![0u8]); @@ -316,72 +330,28 @@ pub struct Compressor { /// Table mapping codes to symbols. pub(crate) symbols: Vec<Symbol>, + /// Length of each symbol, values range from 1-8. + pub(crate) lengths: Vec<u8>, + /// The number of entries in the symbol table that have been populated, not counting /// the escape values. pub(crate) n_symbols: u8, /// Inverted index mapping 2-byte symbols to codes - codes_twobyte: Vec<CodeMeta>, + codes_two_byte: Vec<Code>, + + /// Limit of no suffixes. + has_suffix_code: u8, /// Lossy perfect hash table for looking up codes to symbols that are 3 bytes or more lossy_pht: LossyPHT, } -impl Default for Compressor { - fn default() -> Self { - // NOTE: `vec!` has a specialization for building a new vector of `0u64`. Because Symbol and u64 - // have the same bit pattern, we can allocate as u64 and transmute. If we do `vec![Symbol::EMPTY; N]`, - // that will create a new Vec and call `Symbol::EMPTY.clone()` `N` times which is considerably slower. - let symbols = vec![0u64; 511]; - // SAFETY: transmute safety assured by the compiler. 
- let symbols: Vec = unsafe { std::mem::transmute(symbols) }; - let mut table = Self { - symbols, - n_symbols: 0, - codes_twobyte: vec![CodeMeta::EMPTY; 65_536], - lossy_pht: LossyPHT::new(), - }; - - // Populate the escape byte entries. - for byte in 0..=255 { - table.symbols[byte as usize] = Symbol::from_u8(byte); - } - - table - } -} - /// The core structure of the FSST codec, holding a mapping between `Symbol`s and `Code`s. /// /// The symbol table is trained on a corpus of data in the form of a single byte array, building up /// a mapping of 1-byte "codes" to sequences of up to `N` plaintext bytse, or "symbols". impl Compressor { - /// Attempt to insert a new symbol at the end of the table. - /// - /// # Panics - /// Panics if the table is already full. - pub fn insert(&mut self, symbol: Symbol) -> bool { - assert!(self.n_symbols < 255, "cannot insert into full symbol table"); - - let symbol_len = symbol.len(); - if symbol_len <= 2 { - // Insert the 2-byte symbol into the twobyte cache - self.codes_twobyte[symbol.first_two_bytes() as usize] = - CodeMeta::new_symbol(self.n_symbols, symbol); - } else { - // Attempt to insert larger symbols into the 3-byte cache - if !self.lossy_pht.insert(symbol, self.n_symbols) { - return false; - } - } - - // Insert at the end of the symbols table. - // Note the rescaling from range [0-254] -> [256, 510]. - self.symbols[256 + (self.n_symbols as usize)] = symbol; - self.n_symbols += 1; - true - } - /// Using the symbol table, runs a single cycle of compression on an input word, writing /// the output into `out_ptr`. /// @@ -397,57 +367,98 @@ impl Compressor { /// # Safety /// /// `out_ptr` must never be NULL or otherwise point to invalid memory. - #[inline] + #[inline(never)] pub unsafe fn compress_word(&self, word: u64, out_ptr: *mut u8) -> (usize, usize) { // Speculatively write the first byte of `word` at offset 1. This is necessary if it is an escape, and // if it isn't, it will be overwritten anyway. 
// // SAFETY: caller ensures out_ptr is not null let first_byte = word as u8; - unsafe { out_ptr.byte_add(1).write_unaligned(first_byte) }; + out_ptr.byte_add(1).write_unaligned(first_byte); - // Probe the hash table - let entry = self.lossy_pht.lookup(word); + // First, check the two_bytes table + let code_twobyte = self.codes_two_byte[word as u16 as usize]; - // Now, downshift the `word` and the `entry` to see if they align. - let ignored_bits = entry.ignored_bits; + if code_twobyte.code() < self.has_suffix_code { + // 2 byte code without having to worry about longer matches. + std::ptr::write(out_ptr, code_twobyte.code()); - if !compare_masked(word, entry.symbol.as_u64(), ignored_bits) || entry.is_unused() { - // lookup the appropriate code for the twobyte sequence and write it - // This will hold either 511, OR it will hold the actual code. - let code = self.codes_twobyte[(word as u16) as usize]; - let out = code.code(); - unsafe { - out_ptr.write(out); + // Advance input by symbol length (2) and output by a single code byte + (2, 1) + } else { + // Probe the hash table + let entry = self.lossy_pht.lookup(word); + + // Now, downshift the `word` and the `entry` to see if they align. + let ignored_bits = entry.ignored_bits; + if entry.code != Code::UNUSED + && compare_masked(word, entry.symbol.as_u64(), ignored_bits) + { + // Advance the input by the symbol length (variable) and the output by one code byte + std::ptr::write(out_ptr, entry.code.code()); + (entry.code.len() as usize, 1) + } else { + std::ptr::write(out_ptr, code_twobyte.code()); + + // Advance the input by the symbol length (variable) and the output by either 1 + // byte (if was one-byte code) or two bytes (escape). 
+ ( + code_twobyte.len() as usize, + // Predicated version of: + // + // if entry.code >= 256 { + // 2 + // } else { + // 1 + // } + 1 + (code_twobyte.extended_code() >> 8) as usize, + ) } - - // Advance the input by one byte and the output by 1 byte (if real code) or 2 bytes (if escape). - return ( - if out == ESCAPE_CODE { - 1 - } else { - code.len() as usize - }, - if out == ESCAPE_CODE { 2 } else { 1 }, - ); - } - - let code = entry.code; - unsafe { - out_ptr.write_unaligned(code.code()); } - - (code.len() as usize, 1) } - /// Use the symbol table to compress the plaintext into a sequence of codes and escapes. - pub fn compress(&self, plaintext: &[u8]) -> Vec<u8> { - if plaintext.is_empty() { - return Vec::new(); + /// Compress many lines in bulk. + pub fn compress_bulk(&self, lines: &Vec<&[u8]>) -> Vec<Vec<u8>> { + let mut res = Vec::new(); + + for line in lines { + res.push(self.compress(line)); } - let mut values: Vec<u8> = Vec::with_capacity(2 * plaintext.len()); + res + } + /// Compress a string, writing its result into a target buffer. + /// + /// The target buffer is a byte vector that must have capacity large enough + /// to hold the encoded data. + /// + /// When this call returns, `values` will hold the compressed bytes and have + /// its length set to the length of the compressed text. + /// + /// ``` + /// use fsst::{Compressor, CompressorBuilder, Symbol}; + /// + /// let mut compressor = CompressorBuilder::new(); + /// assert!(compressor.insert(Symbol::from_slice(b"aaaaaaaa"), 8)); + /// + /// let compressor = compressor.build(); + /// + /// let mut compressed_values = Vec::with_capacity(1_024); + /// + /// // SAFETY: we have over-sized compressed_values. + /// unsafe { + /// compressor.compress_into(b"aaaaaaaa", &mut compressed_values); + /// } + /// + /// assert_eq!(compressed_values, vec![0u8]); + /// ``` + /// + /// # Safety + /// + /// It is up to the caller to ensure the provided buffer is large enough to hold + /// all encoded data. 
+ pub unsafe fn compress_into(&self, plaintext: &[u8], values: &mut Vec) { let mut in_ptr = plaintext.as_ptr(); let mut out_ptr = values.as_mut_ptr(); @@ -457,12 +468,12 @@ impl Compressor { // SAFETY: `end` will point just after the end of the `values` allocation. let out_end = unsafe { out_ptr.byte_add(values.capacity()) }; - while (in_ptr as usize) < in_end_sub8 && out_ptr < out_end { + while (in_ptr as usize) <= in_end_sub8 && out_ptr < out_end { // SAFETY: pointer ranges are checked in the loop condition unsafe { // Load a full 8-byte word of data from in_ptr. // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though. - let word: u64 = (in_ptr as *const u64).read_unaligned(); + let word: u64 = std::ptr::read_unaligned(in_ptr as *const u64); let (advance_in, advance_out) = self.compress_word(word, out_ptr); in_ptr = in_ptr.byte_add(advance_in); out_ptr = out_ptr.byte_add(advance_out); @@ -471,39 +482,27 @@ impl Compressor { let remaining_bytes = unsafe { in_end.byte_offset_from(in_ptr) }; assert!( - remaining_bytes.is_positive(), - "in_ptr exceeded in_end, should not be possible" + out_ptr < out_end || remaining_bytes == 0, + "output buffer sized too small" ); + let remaining_bytes = remaining_bytes as usize; // Load the last `remaining_byte`s of data into a final world. We then replicate the loop above, // but shift data out of this word rather than advancing an input pointer and potentially reading // unowned memory. 
- let mut last_word = unsafe { - match remaining_bytes { - 0 => 0, - 1 => extract_u64::<1>(in_ptr), - 2 => extract_u64::<2>(in_ptr), - 3 => extract_u64::<3>(in_ptr), - 4 => extract_u64::<4>(in_ptr), - 5 => extract_u64::<5>(in_ptr), - 6 => extract_u64::<6>(in_ptr), - 7 => extract_u64::<7>(in_ptr), - 8 => extract_u64::<8>(in_ptr), - _ => unreachable!("remaining bytes must be <= 8"), - } - }; + let mut bytes = [0u8; 8]; + std::ptr::copy_nonoverlapping(in_ptr, bytes.as_mut_ptr(), remaining_bytes); + let mut last_word = u64::from_le_bytes(bytes); while in_ptr < in_end && out_ptr < out_end { - unsafe { - // Load a full 8-byte word of data from in_ptr. - // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though. - let (advance_in, advance_out) = self.compress_word(last_word, out_ptr); - in_ptr = in_ptr.byte_add(advance_in); - out_ptr = out_ptr.byte_add(advance_out); + // Load a full 8-byte word of data from in_ptr. + // SAFETY: caller asserts in_ptr is not null. we may read past end of pointer though. + let (advance_in, advance_out) = self.compress_word(last_word, out_ptr); + in_ptr = in_ptr.byte_add(advance_in); + out_ptr = out_ptr.byte_add(advance_out); - last_word = advance_8byte_word(last_word, advance_in); - } + last_word = advance_8byte_word(last_word, advance_in); } // in_ptr should have exceeded in_end @@ -511,36 +510,54 @@ impl Compressor { // Count the number of bytes written // SAFETY: assertion - unsafe { - let bytes_written = out_ptr.offset_from(values.as_ptr()); - assert!( - bytes_written.is_positive(), - "out_ptr ended before it started, not possible" - ); - - values.set_len(bytes_written as usize); + let bytes_written = out_ptr.offset_from(values.as_ptr()); + assert!( + bytes_written.is_positive(), + "out_ptr ended before it started, not possible" + ); + + values.set_len(bytes_written as usize); + } + + /// Use the symbol table to compress the plaintext into a sequence of codes and escapes. 
+ pub fn compress(&self, plaintext: &[u8]) -> Vec { + if plaintext.is_empty() { + return Vec::new(); } - values + let mut buffer = Vec::with_capacity(plaintext.len() * 2); + + // SAFETY: the largest compressed size would be all escapes == 2*plaintext_len + unsafe { self.compress_into(plaintext, &mut buffer) }; + + buffer } /// Access the decompressor that can be used to decompress strings emitted from this /// `Compressor` instance. pub fn decompressor(&self) -> Decompressor { - Decompressor::new(self.symbol_table()) + Decompressor::new(self.symbol_table(), self.symbol_lengths()) } /// Returns a readonly slice of the current symbol table. /// - /// The returned slice will have length of `256 + n_symbols`. + /// The returned slice will have length of `n_symbols`. pub fn symbol_table(&self) -> &[Symbol] { - &self.symbols[0..(256 + self.n_symbols as usize)] + &self.symbols[0..self.n_symbols as usize] + } + + /// Returns a readonly slice where index `i` contains the + /// length of the symbol represented by code `i`. + /// + /// Values range from 1-8. + pub fn symbol_lengths(&self) -> &[u8] { + &self.lengths[0..self.n_symbols as usize] } } #[inline] -fn advance_8byte_word(word: u64, bytes: usize) -> u64 { - // shift the word off the right-end, because little endian means the first +pub(crate) fn advance_8byte_word(word: u64, bytes: usize) -> u64 { + // shift the word off the low-end, because little endian means the first // char is stored in the LSB. 
// // Note that even though this looks like it branches, Rust compiles this to a @@ -553,47 +570,7 @@ fn advance_8byte_word(word: u64, bytes: usize) -> u64 { } #[inline] -fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool { - let mask = if ignored_bits == 64 { - 0 - } else { - u64::MAX >> ignored_bits - }; - +pub(crate) fn compare_masked(left: u64, right: u64, ignored_bits: u16) -> bool { + let mask = u64::MAX >> ignored_bits; (left & mask) == right } - -/// This is a function that will get monomorphized based on the value of `N` to do -/// a load of `N` values from the pointer in a minimum number of instructions into -/// an output `u64`. -#[inline] -unsafe fn extract_u64(ptr: *const u8) -> u64 { - match N { - 1 => ptr.read() as u64, - 2 => (ptr as *const u16).read_unaligned() as u64, - 3 => { - let low = ptr.read() as u64; - let high = (ptr.byte_add(1) as *const u16).read_unaligned() as u64; - high << 8 | low - } - 4 => (ptr as *const u32).read_unaligned() as u64, - 5 => { - let low = (ptr as *const u32).read_unaligned() as u64; - let high = ptr.byte_add(4).read() as u64; - high << 32 | low - } - 6 => { - let low = (ptr as *const u32).read_unaligned() as u64; - let high = (ptr.byte_add(4) as *const u16).read_unaligned() as u64; - high << 32 | low - } - 7 => { - let low = (ptr as *const u32).read_unaligned() as u64; - let mid = (ptr.byte_add(4) as *const u16).read_unaligned() as u64; - let high = ptr.byte_add(6).read() as u64; - (high << 48) | (mid << 32) | low - } - 8 => (ptr as *const u64).read_unaligned(), - _ => unreachable!("N must be <= 8"), - } -} diff --git a/src/lossy_pht.rs b/src/lossy_pht.rs index 1d41243..9570d70 100644 --- a/src/lossy_pht.rs +++ b/src/lossy_pht.rs @@ -1,28 +1,32 @@ +// TODO: remove +#![allow(unused)] + use std::fmt::Debug; -use crate::CodeMeta; +use crate::builder::fsst_hash; use crate::Symbol; -use crate::MAX_CODE; +use crate::FSST_CODE_MASK; +use crate::{Code, FSST_CODE_UNUSED}; /// Size of the perfect hash table. 
/// /// NOTE: this differs from the paper, which recommends a 64KB total /// table size. The paper does not account for the fact that most /// vendors split the L1 cache into 32KB of instruction and 32KB of data. -pub const HASH_TABLE_SIZE: usize = 1 << 11; +pub const HASH_TABLE_SIZE: usize = 1 << 12; /// A single entry in the [Lossy Perfect Hash Table][`LossyPHT`]. /// /// `TableEntry` is based on the `Symbol` class outlined in Algorithm 4 of the FSST paper. See /// the module documentation for a link to the paper. -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Debug)] #[repr(C)] pub(crate) struct TableEntry { /// Symbol, piece of a string, 8 bytes or fewer. pub(crate) symbol: Symbol, /// Code and associated metadata for the symbol - pub(crate) code: CodeMeta, + pub(crate) code: Code, /// Number of ignored bits in `symbol`. /// @@ -35,8 +39,7 @@ assert_sizeof!(TableEntry => 16); impl TableEntry { pub(crate) fn is_unused(&self) -> bool { - // 511 should never come up for real, so use as the sentinel for an unused slot - self.code.extended_code() == MAX_CODE + self.code == Code::UNUSED } } @@ -63,7 +66,7 @@ impl LossyPHT { let slots = vec![ TableEntry { symbol: Symbol::ZERO, - code: CodeMeta::EMPTY, + code: Code::UNUSED, ignored_bits: 64, }; HASH_TABLE_SIZE @@ -79,43 +82,46 @@ impl LossyPHT { /// # Returns /// /// True if the symbol was inserted into the table, false if it was rejected due to collision. 
- pub(crate) fn insert(&mut self, symbol: Symbol, code: u8) -> bool { + pub(crate) fn insert(&mut self, symbol: Symbol, len: usize, code: u8) -> bool { let prefix_3bytes = symbol.as_u64() & 0xFF_FF_FF; - let slot = self.hash(prefix_3bytes) as usize & (HASH_TABLE_SIZE - 1); + let slot = fsst_hash(prefix_3bytes) as usize & (HASH_TABLE_SIZE - 1); let entry = &mut self.slots[slot]; - if !entry.is_unused() { false } else { entry.symbol = symbol; - entry.code = CodeMeta::new_symbol(code, symbol); + entry.code = Code::new_symbol_building(code, len); entry.ignored_bits = (64 - 8 * symbol.len()) as u16; true } } + /// Given a new code mapping, rewrite the codes into the new code range. + pub(crate) fn renumber(&mut self, new_codes: &[u8]) { + for slot in self.slots.iter_mut() { + if slot.code != Code::UNUSED { + let old_code = slot.code.code(); + let new_code = new_codes[old_code as usize]; + let len = slot.code.len(); + slot.code = Code::new_symbol(new_code, len as usize); + } + } + } + /// Remove the symbol from the hashtable, if it exists. pub(crate) fn remove(&mut self, symbol: Symbol) { let prefix_3bytes = symbol.as_u64() & 0xFF_FF_FF; - let slot = self.hash(prefix_3bytes) as usize & (HASH_TABLE_SIZE - 1); - self.slots[slot].code = CodeMeta::EMPTY; + let slot = fsst_hash(prefix_3bytes) as usize & (HASH_TABLE_SIZE - 1); + self.slots[slot].code = Code::UNUSED; } #[inline] - pub(crate) fn lookup(&self, word: u64) -> TableEntry { + pub(crate) fn lookup(&self, word: u64) -> &TableEntry { let prefix_3bytes = word & 0xFF_FF_FF; - let slot = self.hash(prefix_3bytes) as usize & (HASH_TABLE_SIZE - 1); - - // SAFETY: the slot is guaranteed to between 0...(HASH_TABLE_SIZE - 1). - unsafe { *self.slots.get_unchecked(slot) } - } + let slot = fsst_hash(prefix_3bytes) as usize & (HASH_TABLE_SIZE - 1); - /// Hash a value to find the bucket it belongs in. - /// - /// The particular hash function comes from the code listing of Algorithm 4 of the FSST paper. 
- #[inline] - fn hash(&self, value: u64) -> u64 { - (value * 2971215073) ^ (value >> 15) + // SAFETY: the slot is guaranteed to between [0, HASH_TABLE_SIZE). + unsafe { self.slots.get_unchecked(slot) } } } diff --git a/tests/correctness.rs b/tests/correctness.rs index 64f3ba7..2b42e98 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -1,6 +1,6 @@ #![cfg(test)] -use fsst::{Compressor, Symbol}; +use fsst::{Compressor, CompressorBuilder, Symbol}; static PREAMBLE: &str = r#" When in the Course of human events, it becomes necessary for one people to dissolve @@ -16,7 +16,7 @@ static ART_OF_WAR: &str = include_str!("./fixtures/art_of_war.txt"); #[test] fn test_basic() { // Roundtrip the declaration - let trained = Compressor::train(PREAMBLE); + let trained = Compressor::train(&vec![PREAMBLE.as_bytes()]); let compressed = trained.compress(PREAMBLE.as_bytes()); let decompressed = trained.decompressor().decompress(&compressed); assert_eq!(decompressed, PREAMBLE.as_bytes()); @@ -24,7 +24,7 @@ fn test_basic() { #[test] fn test_train_on_empty() { - let trained = Compressor::train(""); + let trained = Compressor::train(&vec![]); // We can still compress with it, but the symbols are going to be empty. 
let compressed = trained.compress("the quick brown fox jumped over the lazy dog".as_bytes()); assert_eq!( @@ -35,9 +35,10 @@ fn test_train_on_empty() { #[test] fn test_one_byte() { - let mut empty = Compressor::default(); - // Assign code 0 to map to the symbol containing byte 0x01 - empty.insert(Symbol::from_u8(0x01)); + let mut empty = CompressorBuilder::new(); + empty.insert(Symbol::from_u8(0x01), 1); + + let empty = empty.build(); let compressed = empty.compress(&[0x01]); assert_eq!(compressed, vec![0u8]); @@ -48,7 +49,7 @@ fn test_one_byte() { #[test] fn test_zeros() { let training_data: Vec = vec![0, 1, 2, 3, 4, 0]; - let trained = Compressor::train(&training_data); + let trained = Compressor::train(&vec![&training_data]); let compressed = trained.compress(&[4, 0]); assert_eq!(trained.decompressor().decompress(&compressed), &[4, 0]); } @@ -58,7 +59,7 @@ fn test_zeros() { fn test_large() { let corpus: Vec = DECLARATION.bytes().cycle().take(10_240).collect(); - let trained = Compressor::train(&corpus); + let trained = Compressor::train(&vec![&corpus]); let massive: Vec = DECLARATION .bytes() .cycle() @@ -71,7 +72,7 @@ fn test_large() { #[test] fn test_chinese() { - let trained = Compressor::train(ART_OF_WAR.as_bytes()); + let trained = Compressor::train(&vec![ART_OF_WAR.as_bytes()]); assert_eq!( ART_OF_WAR.as_bytes(), trained