Commit

Refactored all beam search methods into a single function, added integration tests
guillaume-be committed Jan 24, 2021
1 parent 1fe8c29 commit 56574a1
Showing 4 changed files with 121 additions and 389 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -2,8 +2,14 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- (BREAKING) Implementation of [Diverse Beam Search](https://arxiv.org/abs/1610.02424), allowing the generation of more diverse sequences for a given number of beams. Two new fields were added to the `GenerateConfig` and are propagated through all text generation configs (e.g. `TranslationConfig`):
- `num_beam_groups` (`Option<i64>`), the number of sub-beam groups; must be a divisor of the number of beams.
- `diversity_penalty` (`Option<f64>`), the amount by which to penalize tokens shared between beam groups; defaults to 5.5 if not provided. The impact of diverse beam search is illustrated in the GPT2 integration tests.
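The core mechanic behind these two fields can be sketched as follows. This is a minimal, self-contained illustration of the diversity penalty — not the crate's actual implementation: the function name and signature are hypothetical, and real beam search operates on log-probability tensors. The idea is that when scoring a beam group, every token already emitted by earlier groups at the same generation step is penalized by `diversity_penalty`, steering each group toward different continuations.

```rust
use std::collections::HashMap;

/// Hypothetical sketch: penalize the raw scores of one beam group for
/// tokens already chosen by earlier groups at this generation step.
/// `scores` is indexed by token id; each prior occurrence of a token
/// subtracts `diversity_penalty` from that token's score.
fn apply_diversity_penalty(
    scores: &mut [f64],
    previous_group_tokens: &[usize],
    diversity_penalty: f64,
) {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &tok in previous_group_tokens {
        *counts.entry(tok).or_insert(0) += 1;
    }
    for (tok, count) in counts {
        if tok < scores.len() {
            scores[tok] -= diversity_penalty * count as f64;
        }
    }
}

fn main() {
    // Toy vocabulary of 4 tokens; an earlier group picked token 2 twice.
    let mut scores = vec![1.0, 0.5, 2.0, 0.1];
    apply_diversity_penalty(&mut scores, &[2, 2], 5.5);
    // Token 2 drops from 2.0 to 2.0 - 2 * 5.5 = -9.0, so this group
    // will prefer a different token than the previous group did.
    println!("{:?}", scores);
}
```

With a penalty of 0.0 the groups collapse back to ordinary beam search, which is why `num_beam_groups` and `diversity_penalty` are configured together.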

### Changed
- (BREAKING) Simplified the inputs and outputs of encoder/decoder models to avoid taking ownership of the possibly cached encoder hidden state, offering a minor performance improvement for text generation tasks. The model output field for encoder hidden states is now optional, and is only returned if the encoder hidden states were not provided for the given forward pass. This may be a breaking change for low-level dependencies that directly manipulate the encoder/decoder model outputs.
- (BREAKING) Moved the language models' implementations of the `PrivateLanguageGenerator` and `LanguageGenerator` traits (needed to generate text) to the model modules, cleaning up the `generation_utils` module.
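The caching change above can be sketched with plain stand-in types. This is an illustrative pattern, not the crate's actual API: `Tensor`, `DecoderOutput`, and the helper functions are hypothetical. The point is that the forward pass borrows a cached encoder state when one is supplied, and only returns a freshly computed state when the caller did not already own one — so ownership never needs to round-trip through the model.

```rust
/// Illustrative stand-in for a model tensor (not the crate's actual type).
#[derive(Clone, Debug, PartialEq)]
struct Tensor(Vec<f64>);

struct DecoderOutput {
    logits: Tensor,
    /// Populated only when the encoder actually ran in this forward pass;
    /// `None` when a cached hidden state was passed in by the caller.
    encoder_hidden_states: Option<Tensor>,
}

fn forward(input: &Tensor, cached_encoder_state: Option<&Tensor>) -> DecoderOutput {
    // Borrow the cached state if provided; otherwise run the encoder.
    let (encoder_state, freshly_computed) = match cached_encoder_state {
        Some(state) => (state.clone(), false),
        None => (run_encoder(input), true),
    };
    let logits = run_decoder(input, &encoder_state);
    DecoderOutput {
        logits,
        // Hand the state back only if the caller did not already own it.
        encoder_hidden_states: if freshly_computed { Some(encoder_state) } else { None },
    }
}

// Toy stand-ins for the real encoder/decoder computations.
fn run_encoder(input: &Tensor) -> Tensor {
    Tensor(input.0.iter().map(|x| x * 2.0).collect())
}

fn run_decoder(_input: &Tensor, state: &Tensor) -> Tensor {
    Tensor(state.0.iter().map(|x| x + 1.0).collect())
}

fn main() {
    let input = Tensor(vec![1.0, 2.0]);
    // First step: no cache, so the encoder runs and its state is returned.
    let first = forward(&input, None);
    let cached = first.encoder_hidden_states.expect("state returned on first pass");
    // Subsequent steps reuse the cached state; no state is returned.
    let next = forward(&input, Some(&cached));
    assert!(next.encoder_hidden_states.is_none());
    println!("cached encoder state reused without re-encoding");
}
```

During generation the first decoding step pays for the encoder once, and every later step passes the cached state back in, which is where the minor performance gain comes from.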

### Fixed
- Updated padding information and added position ids for batched GPT2 generation. Prior to this change, inputs that required padding produced lower-quality generated text.
4 changes: 2 additions & 2 deletions benches/translation_benchmark.rs
@@ -41,7 +41,7 @@ fn translation_forward_pass(iters: u64, model: &TranslationModel, data: &[&str])
duration
}

-fn summarization_load_model(iters: u64) -> Duration {
+fn translation_load_model(iters: u64) -> Duration {
let mut duration = Duration::new(0, 0);
for _i in 0..iters {
let start = Instant::now();
@@ -96,7 +96,7 @@ fn bench_squad(c: &mut Criterion) {
});

c.bench_function("Load model", |b| {
-    b.iter_custom(|iters| black_box(summarization_load_model(iters)))
+    b.iter_custom(|iters| black_box(translation_load_model(iters)))
});
}
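The renamed `translation_load_model` follows the manual-timing pattern that criterion's `iter_custom` expects: accumulate only the wall-clock time of the operation under test across `iters` runs, excluding setup. A self-contained sketch of that pattern, without the criterion or rust-bert dependencies (`timed_runs` and its closure are hypothetical names, not code from this repository):

```rust
use std::time::{Duration, Instant};

/// Sketch of the manual-timing loop used with criterion's `iter_custom`:
/// run `op` `iters` times and sum only the elapsed time of each run.
fn timed_runs(iters: u64, mut op: impl FnMut()) -> Duration {
    let mut duration = Duration::new(0, 0);
    for _ in 0..iters {
        let start = Instant::now();
        op();
        // Accumulate only the time spent inside `op`.
        duration += start.elapsed();
    }
    duration
}

fn main() {
    let mut checksum = 0u64;
    // Stand-in workload for loading a translation model.
    let total = timed_runs(10, || checksum += (0..1000).sum::<u64>());
    println!("checksum {} over 10 runs in {:?}", checksum, total);
}
```

Returning the summed `Duration` lets criterion divide by `iters` itself, which is why the benchmark accumulates rather than averaging in the loop.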

