Generation traits simplification (guillaume-be#339)
* - Remove LMHeadModel trait (integrate with PrivateLanguageGenerator)
- Simplify the PrivateLanguageGenerator trait definition (no longer generic over types implementing the `LMHeadModel`, `Vocab` and `Tokenizer` traits)

* - Removed BART duplicated code, updated docs

* - Fixed incorrect order of generation arguments in BART-based models

* - Updated changelog

* Fixed Clippy warning
guillaume-be committed Mar 17, 2023
1 parent c448862 commit b05ec7b
Showing 35 changed files with 718 additions and 1,527 deletions.
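
To make the refactor described in the commit message concrete, the sketch below contrasts the old and new shapes of the generation traits using toy, self-contained types. The names and method bodies are illustrative only and are not the actual rust-bert definitions; they simply mirror the structural change visible in the `bart_model.rs` diff below: the generator trait loses its model/vocab/tokenizer type parameters, and the forward pass moves from a separate `LMHeadModel` trait onto the generator trait itself.

```rust
// Toy illustration of the trait simplification; illustrative names only,
// not the actual rust-bert definitions.

// Before: the private generator trait was generic over the wrapped model
// (bound by a separate LMHeadModel trait), its vocabulary and its tokenizer,
// and reached the forward pass through `get_model()`.
mod before {
    pub trait LmHeadModel {
        fn forward(&self, input: &str) -> String;
    }

    pub trait PrivateLanguageGenerator<M: LmHeadModel, V, T> {
        fn get_model(&self) -> &M;

        fn generate_step(&self, input: &str) -> String {
            self.get_model().forward(input)
        }
    }
}

// After: no type parameters and no separate LM-head trait; each generator
// implements the forward pass directly on the trait.
mod after {
    pub trait PrivateLanguageGenerator {
        fn forward_step(&self, input: &str) -> String;

        fn generate_step(&self, input: &str) -> String {
            self.forward_step(input)
        }
    }
}

struct ToyGenerator;

impl after::PrivateLanguageGenerator for ToyGenerator {
    fn forward_step(&self, input: &str) -> String {
        format!("generated from: {input}")
    }
}

fn main() {
    use after::PrivateLanguageGenerator;
    println!("{}", ToyGenerator.generate_step("hello"));
}
```

Concrete implementors such as `BartGenerator` then only write `impl PrivateLanguageGenerator for BartGenerator { ... }` plus an empty `impl LanguageGenerator for BartGenerator {}`, as shown in the diff below.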
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format

## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
- (BREAKING) Simplified the generation traits (removal of `LMHeadModel` and elimination of the unnecessary type parameters on `LanguageGenerator`)

## Fixed
- MIN/MAX computation for float-like (was set to infinity instead of min/max)
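For context on the "MIN/MAX computation for float-like" entry in the hunk above, the snippet below is a minimal, self-contained illustration of the distinction involved (it is not the library's actual code): for floats, the largest and smallest finite values are `f64::MAX`/`f64::MIN`, which differ from `f64::INFINITY`/`f64::NEG_INFINITY`, and using the infinities where a finite bound is expected lets non-finite values leak through.

```rust
// Illustration only; not rust-bert code.
fn main() {
    // The finite extremes and the infinities are different constants.
    assert!(f64::MAX < f64::INFINITY);
    assert!(f64::MIN > f64::NEG_INFINITY);

    // An overflowing product becomes +inf ...
    let value = 1e308_f64 * 10.0;
    assert!(value.is_infinite());

    // ... and an "upper bound" of INFINITY fails to clamp it,
    // while the finite f64::MAX does.
    assert_eq!(value.min(f64::INFINITY), f64::INFINITY);
    assert_eq!(value.min(f64::MAX), f64::MAX);
}
```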
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -70,7 +70,7 @@ features = ["doc-only"]

[dependencies]
rust_tokenizers = "8.0.0"
tch = "~0.10.1"
tch = "~0.10"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
@@ -88,6 +88,6 @@ anyhow = "1"
csv = "1"
criterion = "0.4"
tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "=0.10.0"
torch-sys = "=0.10"
tempfile = "3"
itertools = "0.10"
1 change: 0 additions & 1 deletion examples/generation_gptj.rs
@@ -44,7 +44,6 @@ use tch::Device;
/// ```
///
/// [gpt-j-6B-float16]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16
///
fn main() -> anyhow::Result<()> {
// Resources paths

173 changes: 52 additions & 121 deletions src/bart/bart_model.rs
@@ -21,12 +21,10 @@ use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::RobertaVocab;
use rust_tokenizers::tokenizer::TruncationStrategy;

use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
@@ -826,7 +824,7 @@ impl BartForSequenceClassification {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap();;
/// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap();
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@@ -891,110 +889,6 @@ impl BartForSequenceClassification {
}
}

impl LMHeadModel for BartForConditionalGeneration {
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked value 1. If None, set to 1
/// * `input_embeds` - Unused for BART
/// * `token_type_ids` - Unused for BART
/// * `position_ids` - Unused for BART
/// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
///
/// # Returns
///
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::pipelines::generation_utils::LMHeadModel;
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// ```
fn forward_t(
&self,
input_ids: Option<&Tensor>,
cache: Cache,
attention_mask: Option<&Tensor>,
_token_type_ids: Option<&Tensor>,
_position_ids: Option<&Tensor>,
_input_embeds: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
None,
cached_layer_states,
train,
),

Cache::None => self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
None,
None,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with BART Model".into(),
));
}
};

let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok(LMModelOutput {
lm_logits,
cache: Cache::BARTCache(base_model_output.cache),
})
}
}

/// Container holding a BART model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for classification or language modeling tasks)
@@ -1143,12 +1037,7 @@ impl BartGenerator {
}
}

impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>
for BartGenerator
{
fn get_model(&self) -> &BartForConditionalGeneration {
&self.model
}
impl PrivateLanguageGenerator for BartGenerator {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
@@ -1183,6 +1072,51 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
self.max_position_embeddings
}

fn forward_t(
&self,
input_ids: Option<&Tensor>,
cache: Cache,
attention_mask: Option<&Tensor>,
_token_type_ids: Option<&Tensor>,
_position_ids: Option<&Tensor>,
_input_embeds: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
decoder_input_ids,
None,
cached_layer_states,
train,
),

Cache::None => self.model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
decoder_input_ids,
None,
None,
train,
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with BART Model".into(),
));
}
};

Ok(LMModelOutput {
lm_logits: base_model_output.decoder_output,
cache: Cache::BARTCache(base_model_output.cache),
})
}

fn prepare_scores_for_generation(
&self,
scores: &mut Tensor,
Expand All @@ -1203,7 +1137,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
}

fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.get_model().encode(input_ids, attention_mask))
Some(self.model.encode(input_ids, attention_mask))
}

fn prepare_inputs_for_generation<'a>(
@@ -1312,10 +1246,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
}
}

impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>
for BartGenerator
{
}
impl LanguageGenerator for BartGenerator {}

#[cfg(test)]
mod test {
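For reference, the `LMHeadModel` implementation removed above computed the language-model logits by projecting the decoder output onto the shared input-embedding matrix (`decoder_output.linear(&embeddings.ws, None)`). The snippet below is a self-contained `tch` sketch of that tied-embedding projection; the tensor shapes and random inputs are illustrative only, and it is not rust-bert code.

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Illustrative dimensions only.
    let (batch, seq_len, hidden, vocab): (i64, i64, i64, i64) = (2, 5, 8, 16);

    // Stand-ins for the decoder hidden states and the tied embedding weights.
    let decoder_output = Tensor::randn(&[batch, seq_len, hidden], (Kind::Float, Device::Cpu));
    let embedding_ws = Tensor::randn(&[vocab, hidden], (Kind::Float, Device::Cpu));

    // Tied LM head: logits = decoder_output x embedding_ws^T, mirroring the
    // `linear` call in the removed `LMHeadModel::forward_t`.
    let lm_logits = decoder_output.linear::<Tensor>(&embedding_ws, None);

    assert_eq!(lm_logits.size(), vec![batch, seq_len, vocab]);
}
```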
2 changes: 1 addition & 1 deletion src/bart/mod.rs
@@ -2,7 +2,7 @@
//!
//! Implementation of the BART language model ([BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) Lewis, Liu, Goyal, Ghazvininejad, Mohamed, Levy, Stoyanov, Zettlemoyer, 2019).
//! The base model is implemented in the `bart_model::BartModel` struct. The model also includes a language model head: `bart_model::BartForConditionalGeneration`
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
