Skip to content

Commit

Permalink
feat: make Compressor::train 2x faster with bitmap index (#16)
Browse files Browse the repository at this point in the history
The slowest part of Compressor::train is the double-nested loops over
codes.

Now compress_count when it records code pairs will also populate a
bitmap index, where `pairs_index[code1].set(code2)` will indicate that
code2 followed code1 in compressed output.

In the `optimize` loop, we can eliminate tight loop iterations by
accessing `pairs_index[code1].second_codes()`, which yields only the
code2 values that were observed to follow code1.

This results in a speedup from ~1ms -> 500micros for the training
benchmark. We're sub-millisecond!

This also makes Miri somewhat palatable to run for all but `test_large`,
so I've re-enabled it for CI (currently it runs in 2.5 minutes. Far cry
from the < 30s build+test step but I guess it's for a good cause)
  • Loading branch information
a10y committed Aug 20, 2024
1 parent b891677 commit d7e836c
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 23 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/miri.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Runs the test suite under Miri on pushes to `develop`, on pull requests,
# and on manual dispatch.
name: Miri

on:
  push:
    branches: ["develop"]
  pull_request: {}
  workflow_dispatch: {}

permissions:
  actions: read
  contents: read

jobs:
  miri:
    name: "miri"
    runs-on: ubuntu-latest
    env:
      RUST_BACKTRACE: 1
      MIRIFLAGS: -Zmiri-strict-provenance -Zmiri-symbolic-alignment-check -Zmiri-backtrace=full
    steps:
      - uses: actions/checkout@v4

      # Read the pinned toolchain channel out of rust-toolchain.toml so the
      # toolchain step below installs exactly that version.
      - name: Rust Version
        id: rust-version
        shell: bash
        run: echo "version=$(cat rust-toolchain.toml | grep channel | awk -F'\"' '{print $2}')" >> $GITHUB_OUTPUT

      # Always install the toolchain: this workflow defines no toolchain cache
      # step whose `cache-hit` output could be used to skip the install.
      - name: Rust Toolchain
        id: rust-toolchain
        uses: dtolnay/rust-toolchain@master
        with:
          toolchain: "${{ steps.rust-version.outputs.version }}"
          components: miri

      - name: Rust Dependency Cache
        uses: Swatinem/rust-cache@v2
        with:
          save-if: ${{ github.ref == 'refs/heads/develop' }}
          shared-key: "shared" # To allow reuse across jobs

      - name: Rust Compile Cache
        uses: mozilla-actions/sccache-action@v0.0.5
      - name: Rust Compile Cache Config
        shell: bash
        run: |
          echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV
          echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV
          echo "CARGO_INCREMENTAL=0" >> $GITHUB_ENV
      - name: Run tests with Miri
        run: cargo miri test
162 changes: 143 additions & 19 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,88 @@ use std::collections::BinaryHeap;

use crate::{Compressor, Symbol, ESCAPE_CODE, MAX_CODE};

/// Fixed-size bitmap with one bit per code.
///
/// Backed by eight `u64` words (512 bits), so it can only track the
/// indices `0..=511`.
#[derive(Clone, Copy, Debug, Default)]
struct CodesBitmap {
/// Storage words: index `i` is represented by bit `i % 64` of word `i / 64`.
codes: [u64; 8],
}

// Eight 8-byte words make 64 bytes. `assert_sizeof!` is defined elsewhere in
// the crate; presumably it enforces this size at compile time.
assert_sizeof!(CodesBitmap => 64);

impl CodesBitmap {
    /// Set the bit for `index`.
    ///
    /// `index` must be between 0 and [`MAX_CODE`][crate::MAX_CODE]. The range is only
    /// checked by a `debug_assert!`; in builds without debug assertions, an index of
    /// 512 or more still panics on the array bounds check below.
    pub(crate) fn set(&mut self, index: usize) {
        debug_assert!(index <= MAX_CODE as usize, "code cannot exceed {MAX_CODE}");

        // Word `index / 64` holds the bit, at position `index % 64` within that word.
        self.codes[index / 64] |= 1u64 << (index % 64);
    }

    /// Get all codes set in this bitmap.
    ///
    /// The returned iterator borrows the bitmap and starts scanning at the first word.
    pub(crate) fn codes(&self) -> CodesIterator<'_> {
        CodesIterator {
            inner: self,
            index: 0,
            block: self.codes[0],
            reference: 0,
        }
    }
}

/// Iterator over the indices set in a [`CodesBitmap`], yielded in ascending order.
struct CodesIterator<'a> {
/// The bitmap being iterated.
inner: &'a CodesBitmap,
/// Position within `inner.codes` of the word currently being scanned.
index: usize,
/// Bits of the current word that have not been yielded yet.
block: u64,
/// The code value that bit 0 of `block` stands for.
reference: usize,
}

impl<'a> Iterator for CodesIterator<'a> {
    type Item = u16;

    /// Yield the next set index in ascending order, or `None` once every word
    /// of the bitmap has been scanned.
    fn next(&mut self) -> Option<Self::Item> {
        // Skip forward to the next word that still has bits pending.
        while self.block == 0 {
            self.index += 1;
            if self.index >= self.inner.codes.len() {
                return None;
            }
            self.block = self.inner.codes[self.index];
            self.reference = self.index * 64;
        }

        // The loop above only exits with a non-zero `block`, so there is a set bit to find.
        let position = self.block.trailing_zeros() as usize;
        let code = self.reference + position;

        // Clear the lowest set bit. `block` is non-zero, so the subtraction cannot
        // underflow, and no shift is involved, so bit 63 needs no special case.
        self.block &= self.block - 1;

        Some(code as u16)
    }
}

/// Frequency statistics recorded for codes and code pairs.
#[derive(Debug, Clone)]
struct Counter {
/// Frequency count for each code.
counts1: Vec<usize>,

/// Frequency count for each code-pair.
///
/// Stored as a flat table: the entry for `(code1, code2)` lives at
/// index `code1 * 511 + code2`.
counts2: Vec<usize>,

/// Bitmap index of pairs that have been set.
///
/// `pair_index[code1].codes()` yields an iterator that can
/// be used to find all possible codes that follow `code1`.
pair_index: Vec<CodesBitmap>,
}

const COUNTS1_SIZE: usize = MAX_CODE as usize;
Expand All @@ -28,16 +103,7 @@ impl Counter {
Self {
counts1: vec![0; COUNTS1_SIZE],
counts2: vec![0; COUNTS2_SIZE],
}
}

/// reset
pub fn reset(&mut self) {
for idx in 0..COUNTS1_SIZE {
self.counts1[idx] = 0;
}
for idx in 0..COUNTS2_SIZE {
self.counts2[idx] = 0;
pair_index: vec![CodesBitmap::default(); COUNTS1_SIZE],
}
}

Expand All @@ -50,6 +116,7 @@ impl Counter {
fn record_count2(&mut self, code1: u16, code2: u16) {
    let (first, second) = (usize::from(code1), usize::from(code2));

    // Bump the pair frequency in the flat `counts2` table...
    self.counts2[first * 511 + second] += 1;
    // ...and remember in the bitmap index that `code2` has followed `code1`.
    self.pair_index[first].set(second);
}

#[inline]
Expand All @@ -62,12 +129,24 @@ impl Counter {
let idx = (code1 as usize) * 511 + (code2 as usize);
self.counts2[idx]
}

/// Returns an iterator over the codes that have been observed
/// to follow `code1`.
///
/// This is the set of all values `code2` where there was
/// previously a call to `self.record_count2(code1, code2)`.
///
/// The iterator borrows from `self`, which the `'_` in the return type makes explicit.
fn second_codes(&self, code1: u16) -> CodesIterator<'_> {
    self.pair_index[usize::from(code1)].codes()
}
}

/// The number of generations used for training. This is taken from the [FSST paper].
///
/// [FSST paper]: https://www.vldb.org/pvldb/vol13/p2649-boncz.pdf
#[cfg(not(miri))]
const MAX_GENERATIONS: usize = 5;
/// Reduced number of training generations used when compiling under Miri (`cfg(miri)`).
///
/// NOTE(review): presumably lowered to keep interpreted test runs short; confirm.
#[cfg(miri)]
const MAX_GENERATIONS: usize = 2;

impl Compressor {
/// Build and train a `Compressor` from a sample corpus of text.
Expand All @@ -87,14 +166,13 @@ impl Compressor {
return compressor;
}

let mut counter = Counter::new();

for _generation in 0..(MAX_GENERATIONS - 1) {
let mut counter = Counter::new();
compressor.compress_count(sample, &mut counter);
compressor = compressor.optimize(&counter, true);
counter.reset();
}

let mut counter = Counter::new();
compressor.compress_count(sample, &mut counter);
compressor.optimize(&counter, true)
}
Expand Down Expand Up @@ -142,9 +220,16 @@ impl Compressor {
fn optimize(&self, counters: &Counter, include_ascii: bool) -> Self {
let mut res = Compressor::default();
let mut pqueue = BinaryHeap::with_capacity(65_536);

for code1 in 0u16..(256u16 + self.n_symbols as u16) {
let symbol1 = self.symbols[code1 as usize];
let mut gain = counters.count1(code1) * symbol1.len();
let count = counters.count1(code1);
// If count is zero, we can skip the whole inner loop.
if count == 0 {
continue;
}

let mut gain = count * symbol1.len();
// NOTE: use heuristic from C++ implementation to boost the gain of single-byte symbols.
// This helps to reduce exception counts.
if code1 < 256 {
Expand All @@ -157,10 +242,10 @@ impl Compressor {
});
}

for code2 in 0u16..(256u16 + self.n_symbols as u16) {
for code2 in counters.second_codes(code1) {
let symbol2 = &self.symbols[code2 as usize];
// If either symbol is zero-length, or if merging would yield a symbol of
// length greater than 8, skip.

// If merging would yield a symbol of length greater than 8, skip.
if symbol1.len() + symbol2.len() > 8 {
continue;
}
Expand Down Expand Up @@ -247,8 +332,7 @@ impl Ord for Candidate {

#[cfg(test)]
mod test {

use crate::{Compressor, ESCAPE_CODE};
use crate::{builder::CodesBitmap, Compressor, ESCAPE_CODE};

#[test]
fn test_builder() {
Expand Down Expand Up @@ -282,4 +366,44 @@ mod test {
]
);
}

#[test]
fn test_bitmap() {
    // A handful of scattered bits come back in ascending order.
    let mut sparse = CodesBitmap::default();
    for bit in [10, 100, 500] {
        sparse.set(bit);
    }
    let found: Vec<u16> = sparse.codes().collect();
    assert_eq!(found, vec![10u16, 100, 500]);

    // An untouched bitmap yields nothing.
    let empty = CodesBitmap::default();
    assert_eq!(empty.codes().collect::<Vec<u16>>(), Vec::<u16>::new());

    // Edge case: only the first bit of every 64-bit word is set.
    let mut word_starts = CodesBitmap::default();
    for word in 0..8 {
        word_starts.set(64 * word);
    }
    let expected_starts: Vec<u16> = (0u16..8).map(|word| 64 * word).collect();
    assert_eq!(word_starts.codes().collect::<Vec<_>>(), expected_starts);

    // Every one of the 512 representable values is set.
    let mut full = CodesBitmap::default();
    (0..512).for_each(|bit| full.set(bit));
    let expected_all: Vec<u16> = (0u16..512u16).collect();
    assert_eq!(full.codes().collect::<Vec<_>>(), expected_all);
}

// The range check in `CodesBitmap::set` is a `debug_assert!`, so the expected
// panic message is only produced when debug assertions are enabled. Without
// them, `set(512)` still panics, but on the array bounds check with a
// different message, which would fail the `expected` match below.
#[cfg(debug_assertions)]
#[test]
#[should_panic(expected = "code cannot exceed")]
fn test_bitmap_invalid() {
    let mut map = CodesBitmap::default();
    map.set(512);
}
}
5 changes: 1 addition & 4 deletions tests/correctness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,13 @@ fn test_one_byte() {

#[test]
fn test_zeros() {
println!("training zeros");
let training_data: Vec<u8> = vec![0, 1, 2, 3, 4, 0];
let trained = Compressor::train(&training_data);
println!("compressing with zeros");
let compressed = trained.compress(&[4, 0]);
println!("decomperssing with zeros");
assert_eq!(trained.decompressor().decompress(&compressed), &[4, 0]);
println!("done");
}

#[cfg_attr(miri, ignore)]
#[test]
fn test_large() {
let corpus: Vec<u8> = DECLARATION.bytes().cycle().take(10_240).collect();
Expand Down

0 comments on commit d7e836c

Please sign in to comment.