From b05ec7b24ff93fc8e8ef4f3fa15a358b8b83398d Mon Sep 17 00:00:00 2001 From: guillaume-be Date: Fri, 17 Mar 2023 16:21:37 +0000 Subject: [PATCH] Generation traits simplification (#339) * - Remove LMHeadModel trait (integrate with PrivateLanguageGenerator) - Simplify PrivateLanguageGenerator trait definition (no longer requires defined by objects implementing `LMHeadModel`, `Vocab` and `Tokenizer` traits) * - Removed BART duplicated code, updated docs * - Fixed BART-based model incorrect order of generation arguments * - Updated changelog * Fixed Clippy warning --- CHANGELOG.md | 1 + Cargo.toml | 4 +- examples/generation_gptj.rs | 1 - src/bart/bart_model.rs | 173 ++++++++------------------- src/bart/mod.rs | 2 +- src/gpt2/gpt2_model.rs | 164 ++++++++----------------- src/gpt2/mod.rs | 2 +- src/gpt_j/gpt_j_model.rs | 144 +++++++++------------- src/gpt_j/mod.rs | 2 +- src/gpt_neo/gpt_neo_model.rs | 103 ++++++++-------- src/longt5/longt5_model.rs | 184 +++++++++-------------------- src/longt5/mod.rs | 2 +- src/m2m_100/m2m_100_model.rs | 172 ++++++++------------------- src/m2m_100/mod.rs | 2 +- src/marian/marian_model.rs | 176 ++++++++------------------- src/marian/mod.rs | 2 +- src/mbart/mbart_model.rs | 171 ++++++++------------------- src/mbart/mod.rs | 2 +- src/openai_gpt/mod.rs | 2 +- src/openai_gpt/openai_gpt_model.rs | 49 +++++--- src/pegasus/mod.rs | 2 +- src/pegasus/pegasus_model.rs | 179 ++++++++-------------------- src/pipelines/generation_utils.rs | 118 +++--------------- src/prophetnet/mod.rs | 2 +- src/prophetnet/prophetnet_model.rs | 123 +++++++++---------- src/reformer/mod.rs | 2 +- src/reformer/reformer_model.rs | 95 +++++++-------- src/t5/mod.rs | 2 +- src/t5/t5_model.rs | 180 ++++++++-------------------- src/xlnet/mod.rs | 2 +- src/xlnet/xlnet_model.rs | 150 +++++++---------------- tests/distilgpt2.rs | 14 +-- tests/gpt2.rs | 14 +-- tests/gpt_j.rs | 2 +- tests/openai_gpt.rs | 2 +- 35 files changed, 718 insertions(+), 1527 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cf7fb482e..ae623c31b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file. The format ## Changed - Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer. +- (BREAKING) Simplified the generation traits (removal of LMHeadModel and elimination of unnecessary specification for LanguageGenerator) ## Fixed - MIN/MAX computation for float-like (was set to infinity instead of min/max) diff --git a/Cargo.toml b/Cargo.toml index 6a3e38b5d..2b88eb0ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,7 +70,7 @@ features = ["doc-only"] [dependencies] rust_tokenizers = "8.0.0" -tch = "~0.10.1" +tch = "~0.10" serde_json = "1" serde = { version = "1", features = ["derive"] } ordered-float = "3" @@ -88,6 +88,6 @@ anyhow = "1" csv = "1" criterion = "0.4" tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] } -torch-sys = "=0.10.0" +torch-sys = "=0.10" tempfile = "3" itertools = "0.10" diff --git a/examples/generation_gptj.rs b/examples/generation_gptj.rs index 391d53dd2..7eb521553 100644 --- a/examples/generation_gptj.rs +++ b/examples/generation_gptj.rs @@ -44,7 +44,6 @@ use tch::Device; /// ``` /// /// [gpt-j-6B-float16]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16 -/// fn main() -> anyhow::Result<()> { // Resources paths diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index 1b54c9656..b8c7a99a4 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -21,12 +21,10 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::{RobertaTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::RobertaVocab; +use rust_tokenizers::tokenizer::TruncationStrategy; + use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::collections::HashMap; @@ -826,7 +824,7 @@ impl BartForSequenceClassification { /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = BartConfig::from_file(config_path); - /// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap();; + /// # let bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config).unwrap(); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); @@ -891,110 +889,6 @@ impl BartForSequenceClassification { } } -impl LMHeadModel for BartForConditionalGeneration { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Unused for BART - /// * `token_type_ids` - Unused for BART - /// * `position_ids` - Unused for BART - /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `BartCache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for - /// both the self attention and the encoder cross attention of each layer of the decoder. - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::pipelines::generation_utils::LMHeadModel; - /// use rust_bert::bart::{BartForConditionalGeneration, BartConfig}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = BartConfig::from_file(config_path); - /// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config); - /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); - /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); - /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// - /// let model_output = no_grad(|| { - /// bart_model - /// .forward_t(Some(&input_tensor), - /// Some(&encoder_attention_mask), - /// None, - /// Some(&target_tensor), - /// Some(&decoder_attention_mask), - /// None, - /// false) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - cached_layer_states, - train, - ), - - Cache::None => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - None, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with BART Model".into(), - )); - } - }; - - let lm_logits = base_model_output - .decoder_output - .linear::(&self.base_model.embeddings.ws, None); - Ok(LMModelOutput { - lm_logits, - cache: Cache::BARTCache(base_model_output.cache), - }) - } -} - /// Container holding a BART model output. The decoder output may hold the hidden state of /// the last layer of the decoder, or may hold logits for a custom head module after the /// decoder (e.g. for classification or language modeling tasks) @@ -1143,12 +1037,7 @@ impl BartGenerator { } } -impl PrivateLanguageGenerator - for BartGenerator -{ - fn get_model(&self) -> &BartForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for BartGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -1183,6 +1072,51 @@ impl PrivateLanguageGenerator, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::BARTCache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + cached_layer_states, + train, + ), + + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with BART Model".into(), + )); + } + }; + + Ok(LMModelOutput { + lm_logits: base_model_output.decoder_output, + cache: Cache::BARTCache(base_model_output.cache), + }) + } + fn prepare_scores_for_generation( &self, scores: &mut Tensor, @@ -1203,7 +1137,7 @@ impl PrivateLanguageGenerator) -> Option { - Some(self.get_model().encode(input_ids, attention_mask)) + Some(self.model.encode(input_ids, attention_mask)) } fn prepare_inputs_for_generation<'a>( @@ -1312,10 +1246,7 @@ impl PrivateLanguageGenerator - for BartGenerator -{ -} +impl LanguageGenerator for BartGenerator {} #[cfg(test)] mod test { diff --git a/src/bart/mod.rs b/src/bart/mod.rs index 9ca0be400..09631991a 100644 --- a/src/bart/mod.rs +++ b/src/bart/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the BART language model ([BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461) Lewis, Liu, Goyal, Ghazvininejad, Mohamed, Levy, Stoyanov, Zettlemoyer, 2019). //! The base model is implemented in the `bart_model::BartModel` struct. The model also includes a language model head: `bart_model::BartForConditionalGeneration` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index cbf442dcf..0431de783 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -20,12 +20,8 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::Gpt2Tokenizer; -use rust_tokenizers::vocab::Gpt2Vocab; use serde::{Deserialize, Serialize}; use std::borrow::{Borrow, BorrowMut}; use tch::kind::Kind::Int64; @@ -529,118 +525,26 @@ impl GPT2LMHeadModel { GPT2LMHeadModel { transformer } } -} -impl LMHeadModel for GPT2LMHeadModel { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of size *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`) - /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings. - /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input. - /// * `_encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*). Unused for GPT2 - /// * `_decoder_input_ids` - Optional tensor of shape (*batch size*, *target_sequence_length*). Unused for GPT2 - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `Gpt2Cache` made of `Option>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*) - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config}; - /// use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = Gpt2Config::from_file(config_path); - /// # let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config); - /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); - /// let mut past: Vec = Vec::with_capacity(config.n_layer as usize); - /// for _ in 0..config.n_layer as usize { - /// past.push(Tensor::rand( - /// &[ - /// 2, - /// batch_size, - /// config.n_head, - /// past_sequence_length, - /// config.n_embd / config.n_head, - /// ], - /// (Double, device), - /// )) - /// } - /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); - /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) - /// .expand(&[batch_size, sequence_length], true); - /// - /// let model_output = no_grad(|| { - /// gpt2_model - /// .forward_t( - /// Some(&input_tensor), - /// Cache::GPT2Cache(Some(past)), - /// Some(&attention_mask), - /// Some(&token_type_ids), - /// Some(&position_ids), - /// None, - /// None, - /// None, - /// false, - /// ) - /// .unwrap() - /// }); - /// ``` - fn forward_t( + pub fn forward_t( &self, input_ids: Option<&Tensor>, - layer_past: Cache, + layer_past: Option<&Vec>, attention_mask: Option<&Tensor>, token_type_ids: Option<&Tensor>, position_ids: Option<&Tensor>, input_embeds: Option<&Tensor>, - _encoder_outputs: Option<&Tensor>, - _decoder_input_ids: Option<&Tensor>, train: bool, ) -> Result { - let base_model_output = match layer_past { - Cache::GPT2Cache(layer_past) => self.transformer.forward_t( - input_ids, - layer_past.as_ref(), - attention_mask, - token_type_ids, - position_ids, - input_embeds, - train, - ), - Cache::None => self.transformer.forward_t( - input_ids, - None, - attention_mask, - token_type_ids, - position_ids, - input_embeds, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with GPT2 Model".into(), - )); - } - }?; + let base_model_output = self.transformer.forward_t( + input_ids, + layer_past, + attention_mask, + token_type_ids, + position_ids, + input_embeds, + train, + )?; let lm_logits = base_model_output .output @@ -769,10 +673,7 @@ impl GPT2Generator { } } -impl PrivateLanguageGenerator for GPT2Generator { - fn get_model(&self) -> &GPT2LMHeadModel { - &self.model - } +impl PrivateLanguageGenerator for GPT2Generator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -807,6 +708,43 @@ impl PrivateLanguageGenerator for GPT self.max_position_embeddings } + fn forward_t( + &self, + input_ids: Option<&Tensor>, + layer_past: Cache, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + _encoder_outputs: Option<&Tensor>, + _decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + match layer_past { + Cache::GPT2Cache(layer_past) => self.model.forward_t( + input_ids, + layer_past.as_ref(), + attention_mask, + token_type_ids, + position_ids, + input_embeds, + train, + ), + Cache::None => self.model.forward_t( + input_ids, + None, + attention_mask, + token_type_ids, + position_ids, + input_embeds, + train, + ), + _ => Err(RustBertError::ValueError( + "Cache not compatible with GPT2 Model".into(), + )), + } + } + fn prepare_inputs_for_generation<'a>( &self, input_ids: Tensor, @@ -875,4 +813,4 @@ impl PrivateLanguageGenerator for GPT } } -impl LanguageGenerator for GPT2Generator {} +impl LanguageGenerator for GPT2Generator {} diff --git a/src/gpt2/mod.rs b/src/gpt2/mod.rs index b86725cb8..4e048ff02 100644 --- a/src/gpt2/mod.rs +++ b/src/gpt2/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the GPT2 language model ([Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) Radford, Wu, Child, Luan, Amodei, Sutskever 2019). //! The base model is implemented in the `gpt2_model::Gpt2Model` struct. The model also includes a language model head: `gpt2_model::GPT2LMHeadModel` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/gpt_j/gpt_j_model.rs b/src/gpt_j/gpt_j_model.rs index 073eb756a..527ef95e3 100644 --- a/src/gpt_j/gpt_j_model.rs +++ b/src/gpt_j/gpt_j_model.rs @@ -20,12 +20,8 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::Gpt2Tokenizer; -use rust_tokenizers::vocab::Gpt2Vocab; use serde::{Deserialize, Serialize}; use std::borrow::{Borrow, BorrowMut}; use tch::nn::{embedding, Linear}; @@ -46,7 +42,7 @@ pub struct GptJMergesResources; /// Model weights for Rust are not available out of the box for GPT-J but can be created /// simply with the following command: /// -/// ``` +/// ```ignore /// python utils/convert_model.py path/to/gpt_j/pytorch_model.bin /// ``` /// @@ -57,7 +53,6 @@ pub struct GptJMergesResources; /// /// [gpt-j-6B]: https://huggingface.co/EleutherAI/gpt-j-6B/tree/main /// [gpt-j-6B-float16]:https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16 -/// impl GptJModelResources { pub const GPT_J_TINY_RANDOM: (&'static str, &'static str) = ( "gpt-j-tiny-random/model", @@ -335,7 +330,7 @@ impl GptJModel { /// gpt_j_model /// .forward_t( /// Some(&input_tensor), - /// Some(&past), + /// Some(past), /// Some(&attention_mask), /// Some(&token_type_ids), /// None, @@ -450,7 +445,7 @@ impl GptJLMHeadModel { /// # Example /// /// ```no_run - /// use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig}; + /// use rust_bert::gpt_j::{GptJConfig, GptJLMHeadModel}; /// use rust_bert::Config; /// use std::path::Path; /// use tch::{nn, Device}; @@ -483,82 +478,8 @@ impl GptJLMHeadModel { lm_head, } } -} -impl LMHeadModel for GptJLMHeadModel { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of size *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`) - /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings. - /// * `_position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input. - /// * `_encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*). Unused for GPT-J - /// * `_decoder_input_ids` - Optional tensor of shape (*batch size*, *target_sequence_length*). Unused for GPT_J - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `GptJCache` made of `Option>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*) - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig}; - /// use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = GptJConfig::from_file(config_path); - /// # let mut gpt_j_model: GptJLMHeadModel = GptJLMHeadModel::new(&vs.root(), &config); - /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); - /// let mut past: Vec = Vec::with_capacity(config.n_layer as usize); - /// for _ in 0..config.n_layer as usize { - /// past.push(Tensor::rand( - /// &[ - /// 2, - /// batch_size, - /// config.n_head, - /// past_sequence_length, - /// config.n_embd / config.n_head, - /// ], - /// (Double, device), - /// )) - /// } - /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); - /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) - /// .expand(&[batch_size, sequence_length], true); - /// - /// let model_output = no_grad(|| { - /// gpt_j_model - /// .forward_t( - /// Some(&input_tensor), - /// Cache::GPTJCache(Some(past)), - /// Some(&attention_mask), - /// Some(&token_type_ids), - /// None, - /// None, - /// None, - /// None, - /// false, - /// ) - /// .unwrap() - /// }); - /// ``` - fn forward_t( + pub fn forward_t( &self, input_ids: Option<&Tensor>, layer_past: Cache, @@ -648,7 +569,7 @@ impl GptJGenerator { /// use rust_bert::pipelines::generation_utils::GenerateConfig; /// /// let generate_config = GenerateConfig { - /// max_length: 30, + /// max_length: Some(30), /// do_sample: true, /// num_beams: 5, /// temperature: 1.1, @@ -728,10 +649,7 @@ impl GptJGenerator { } } -impl PrivateLanguageGenerator for GptJGenerator { - fn get_model(&self) -> &GptJLMHeadModel { - &self.model - } +impl PrivateLanguageGenerator for GptJGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -766,6 +684,52 @@ impl PrivateLanguageGenerator for Gpt self.max_position_embeddings } + fn forward_t( + &self, + input_ids: Option<&Tensor>, + layer_past: Cache, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + _encoder_outputs: Option<&Tensor>, + _decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match layer_past { + Cache::GPTJCache(layer_past) => self.model.transformer.forward_t( + input_ids, + layer_past, + attention_mask, + token_type_ids, + position_ids, + input_embeds, + train, + ), + Cache::None => self.model.transformer.forward_t( + input_ids, + None, + attention_mask, + token_type_ids, + position_ids, + input_embeds, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with GPT-J Model".into(), + )); + } + }?; + + let lm_logits = base_model_output.output.apply(&self.model.lm_head); + + Ok(LMModelOutput { + lm_logits, + cache: Cache::GPTJCache(base_model_output.cache), + }) + } + fn prepare_inputs_for_generation<'a>( &self, input_ids: Tensor, @@ -833,4 +797,4 @@ impl PrivateLanguageGenerator for Gpt } } -impl LanguageGenerator for GptJGenerator {} +impl LanguageGenerator for GptJGenerator {} diff --git a/src/gpt_j/mod.rs b/src/gpt_j/mod.rs index ef375ed8f..91618d685 100644 --- a/src/gpt_j/mod.rs +++ b/src/gpt_j/mod.rs @@ -9,7 +9,7 @@ //! # //! use tch::{nn, Device}; //! # use std::path::PathBuf; -//! use rust_bert::gpt_j::{GptJLMHeadModel, GptJConfig}; +//! use rust_bert::gpt_j::{GptJConfig, GptJLMHeadModel}; //! use rust_bert::resources::{LocalResource, ResourceProvider}; //! use rust_bert::Config; //! use rust_tokenizers::tokenizer::Gpt2Tokenizer; diff --git a/src/gpt_neo/gpt_neo_model.rs b/src/gpt_neo/gpt_neo_model.rs index 97f50fed0..d49914d89 100644 --- a/src/gpt_neo/gpt_neo_model.rs +++ b/src/gpt_neo/gpt_neo_model.rs @@ -18,12 +18,8 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::{Activation, Config, RustBertError}; -use rust_tokenizers::tokenizer::Gpt2Tokenizer; -use rust_tokenizers::vocab::Gpt2Vocab; use serde::{Deserialize, Serialize}; use std::borrow::{Borrow, BorrowMut}; use tch::{nn, Kind, Tensor}; @@ -570,52 +566,6 @@ impl GptNeoForCausalLM { } } -impl LMHeadModel for GptNeoForCausalLM { - fn forward_t( - &self, - input_ids: Option<&Tensor>, - layer_past: Cache, - attention_mask: Option<&Tensor>, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor>, - input_embeds: Option<&Tensor>, - _encoder_outputs: Option<&Tensor>, - _decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match layer_past { - Cache::GPTNeoCache(layer_past) => self.forward_t( - input_ids, - input_embeds, - token_type_ids, - position_ids, - layer_past, - attention_mask, - train, - ), - Cache::None => self.forward_t( - input_ids, - input_embeds, - token_type_ids, - position_ids, - None, - attention_mask, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with GPT-Neo Model".into(), - )); - } - }?; - - Ok(LMModelOutput { - lm_logits: base_model_output.lm_logits, - cache: Cache::GPTNeoCache(base_model_output.next_cache), - }) - } -} - /// Container for the GPT-Neo model output. pub struct GptNeoModelOutput { /// Last hidden states from the model @@ -743,10 +693,7 @@ impl GptNeoGenerator { } } -impl PrivateLanguageGenerator for GptNeoGenerator { - fn get_model(&self) -> &GptNeoForCausalLM { - &self.model - } +impl PrivateLanguageGenerator for GptNeoGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -777,10 +724,54 @@ impl PrivateLanguageGenerator for G fn get_decoder_start_id(&self) -> Option { self.decoder_start_id } + fn get_max_positions_embeddings(&self) -> i64 { self.max_position_embeddings } + fn forward_t( + &self, + input_ids: Option<&Tensor>, + layer_past: Cache, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + _encoder_outputs: Option<&Tensor>, + _decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match layer_past { + Cache::GPTNeoCache(layer_past) => self.model.forward_t( + input_ids, + input_embeds, + token_type_ids, + position_ids, + layer_past, + attention_mask, + train, + ), + Cache::None => self.model.forward_t( + input_ids, + input_embeds, + token_type_ids, + position_ids, + None, + attention_mask, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with GPT-Neo Model".into(), + )); + } + }?; + + Ok(LMModelOutput { + lm_logits: base_model_output.lm_logits, + cache: Cache::GPTNeoCache(base_model_output.next_cache), + }) + } fn prepare_inputs_for_generation<'a>( &self, input_ids: Tensor, @@ -851,4 +842,4 @@ impl PrivateLanguageGenerator for G } } -impl LanguageGenerator for GptNeoGenerator {} +impl LanguageGenerator for GptNeoGenerator {} diff --git a/src/longt5/longt5_model.rs b/src/longt5/longt5_model.rs index 53fb9d2cf..5dfa2b7fd 100644 --- a/src/longt5/longt5_model.rs +++ b/src/longt5/longt5_model.rs @@ -16,13 +16,10 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::t5::{FeedForwardProj, T5Config, T5ModelOutput, TaskSpecificParams}; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::{T5Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::T5Vocab; +use rust_tokenizers::tokenizer::TruncationStrategy; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use tch::nn::{embedding, LinearConfig}; @@ -548,124 +545,6 @@ impl LongT5ForConditionalGeneration { } } -impl LMHeadModel for LongT5ForConditionalGeneration { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `cache` - `Cache` object containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Unused for LongT5 - /// * `token_type_ids` - Unused for LongT5 - /// * `position_ids` - Unused for LongT5 - /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `T5Cache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for - /// both the self attention and the encoder cross attention of each layer of the decoder. - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = LongT5Config::from_file(config_path); - /// # let longt5_model: LongT5ForConditionalGeneration = LongT5ForConditionalGeneration::new(&vs.root(), &config); - /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); - /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); - /// let encoder_attention_mask = - /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// let decoder_attention_mask = - /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// - /// let model_output = no_grad(|| { - /// longt5_model.forward_t( - /// Some(&input_tensor), - /// Some(&encoder_attention_mask), - /// None, - /// Some(&target_tensor), - /// Some(&decoder_attention_mask), - /// None, - /// None, - /// None, - /// false, - /// ) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::LongT5Cache(cached_layer_states) => self.base_model.forward_t( - input_ids, - attention_mask, - encoder_outputs, - decoder_input_ids, - None, - None, - None, - cached_layer_states, - train, - )?, - Cache::None => self.base_model.forward_t( - input_ids, - attention_mask, - encoder_outputs, - decoder_input_ids, - None, - None, - None, - None, - train, - )?, - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with LongT5 Model".into(), - )); - } - }; - - let lm_logits = if self.tie_word_embeddings { - base_model_output - .decoder_output - .linear::(&self.base_model.embeddings.ws, None) - * (self.model_dim.powf(-0.5)) - } else { - base_model_output - .decoder_output - .apply(self.lm_head.as_ref().unwrap()) - }; - - Ok(LMModelOutput { - lm_logits, - cache: Cache::LongT5Cache(base_model_output.next_cache), - }) - } -} - /// Container holding a LongT5 model output. pub type LongT5ModelOutput = T5ModelOutput; @@ -742,12 +621,7 @@ impl LongT5Generator { } } -impl PrivateLanguageGenerator - for LongT5Generator -{ - fn get_model(&self) -> &LongT5ForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for LongT5Generator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -782,8 +656,56 @@ impl PrivateLanguageGenerator, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::LongT5Cache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + None, + cached_layer_states, + train, + )?, + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + None, + None, + train, + )?, + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with LongT5 Model".into(), + )); + } + }; + + Ok(LMModelOutput { + lm_logits: base_model_output.decoder_output, + cache: Cache::LongT5Cache(base_model_output.next_cache), + }) + } + fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option { - Some(self.get_model().encode(input_ids, attention_mask)) + Some(self.model.encode(input_ids, attention_mask)) } fn prepare_inputs_for_generation<'a>( @@ -891,4 +813,4 @@ impl PrivateLanguageGenerator for LongT5Generator {} +impl LanguageGenerator for LongT5Generator {} diff --git a/src/longt5/mod.rs b/src/longt5/mod.rs index 091040fdb..fa435e2eb 100644 --- a/src/longt5/mod.rs +++ b/src/longt5/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the LongT5 language model ([LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) Guo, Ainslie, Uthus, Ontanon, Ni, Sung, Yang, 2021). //! The base model is implemented in the `longt5_model::LongT5Model` struct. This model includes a language model head: `longt5_model::LongT5ForConditionalGeneration` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs index 59e4b58dd..74466aa45 100644 --- a/src/m2m_100/m2m_100_model.rs +++ b/src/m2m_100/m2m_100_model.rs @@ -18,13 +18,10 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::pipelines::translation::Language; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::{M2M100Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::M2M100Vocab; +use rust_tokenizers::tokenizer::TruncationStrategy; use std::borrow::Borrow; use tch::nn::{embedding, EmbeddingConfig}; use tch::{nn, Kind, Tensor}; @@ -458,109 +455,6 @@ impl M2M100ForConditionalGeneration { } } -impl LMHeadModel for M2M100ForConditionalGeneration { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Unused for M2M100 - /// * `token_type_ids` - Unused for M2M100 - /// * `position_ids` - Unused for M2M100 - /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `BARTCache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for - /// both the self attention and the encoder cross attention of each layer of the decoder. - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::pipelines::generation_utils::LMHeadModel; - /// use rust_bert::m2m_100::{M2M100ForConditionalGeneration, M2M100Config}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = M2M100Config::from_file(config_path); - /// # let m2m100_model: M2M100ForConditionalGeneration = M2M100ForConditionalGeneration::new(&vs.root(), &config); - /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); - /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); - /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// - /// let model_output = no_grad(|| { - /// m2m100_model - /// .forward_t(Some(&input_tensor), - /// Some(&encoder_attention_mask), - /// None, - /// Some(&target_tensor), - /// Some(&decoder_attention_mask), - /// None, - /// false) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - cached_layer_states, - train, - ), - - Cache::None => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - None, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with M2M100 Model".into(), - )); - } - }; - - let lm_logits = base_model_output - .decoder_output - .linear::(&self.base_model.embeddings.ws, None); - Ok(LMModelOutput { - lm_logits, - cache: Cache::BARTCache(base_model_output.cache), - }) - } -} - /// # Language generation model based on the M2M100 architecture pub struct M2M100Generator { model: M2M100ForConditionalGeneration, @@ -689,12 +583,7 @@ impl M2M100Generator { } } -impl PrivateLanguageGenerator - for M2M100Generator -{ - fn get_model(&self) -> &M2M100ForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for M2M100Generator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -725,11 +614,55 @@ impl PrivateLanguageGenerator Option { self.decoder_start_id } - fn get_max_positions_embeddings(&self) -> i64 { self.max_position_embeddings } + fn forward_t( + &self, + input_ids: Option<&Tensor>, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::BARTCache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + cached_layer_states, + train, + ), + + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with M2M100 Model".into(), + )); + } + }; + + Ok(LMModelOutput { + lm_logits: base_model_output.decoder_output, + cache: Cache::BARTCache(base_model_output.cache), + }) + } + fn prepare_scores_for_generation( &self, scores: &mut Tensor, @@ -747,7 +680,7 @@ impl PrivateLanguageGenerator) -> Option { - Some(self.get_model().encode(input_ids, attention_mask)) + Some(self.model.encode(input_ids, attention_mask)) } fn prepare_inputs_for_generation<'a>( @@ -856,10 +789,7 @@ impl PrivateLanguageGenerator - for M2M100Generator -{ -} +impl LanguageGenerator for M2M100Generator {} #[cfg(test)] mod test { @@ -882,7 +812,7 @@ mod test { // Set-up masked LM model let device = Device::cuda_if_available(); - let vs = tch::nn::VarStore::new(device); + let vs = nn::VarStore::new(device); let config = M2M100Config::from_file(config_path); let _: Box = Box::new(M2M100Model::new(vs.root(), &config)); diff --git a/src/m2m_100/mod.rs b/src/m2m_100/mod.rs index b1357f91b..850b920ef 100644 --- a/src/m2m_100/mod.rs +++ b/src/m2m_100/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the M2M-100 language model ([Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) Fan, Bhosale, Schwenk, Ma, El-Kishky, Goyal, Baines, Celebi, Wenzel, Chaudhary, Goyal, Birch, Liptchinsky, Edunov, Grave, Auli, Joulin, 2020). //! The base model is implemented in the `m2m_100::M2M100Model` struct. The model also includes a language model head: `m2m_100::M2M100ForConditionalGeneration` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! This model allows for direct translation between 100 languages. //! The translation capabilities are illustrated in `examples/translation_m2m100`, run with `cargo run --example translation_m2m100`. //! diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index 0dde3e096..9a858203a 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -16,13 +16,10 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::pipelines::translation::Language; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::{MarianTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::MarianVocab; +use rust_tokenizers::tokenizer::TruncationStrategy; use std::borrow::Borrow; use tch::nn::Init; use tch::{nn, Kind, Tensor}; @@ -653,7 +650,8 @@ impl MarianForConditionalGeneration { let lm_logits = base_model_output .decoder_output - .linear::(&self.base_model.embeddings.ws, None); + .linear::(&self.base_model.embeddings.ws, None) + + &self.final_logits_bias; BartModelOutput { decoder_output: lm_logits, ..base_model_output @@ -673,114 +671,6 @@ impl MarianForConditionalGeneration { } } -impl LMHeadModel for MarianForConditionalGeneration { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Unused for BART - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Unused for BART - /// * `token_type_ids` - Unused for BART - /// * `position_ids` - Unused for BART - /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). - /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `BartCache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for - /// both the self attention and the encoder cross attention of each layer of the decoder. - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::bart::BartConfig; - /// use rust_bert::marian::MarianForConditionalGeneration; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = BartConfig::from_file(config_path); - /// # let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config); - /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); - /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); - /// let encoder_attention_mask = - /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// let decoder_attention_mask = - /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// - /// let model_output = no_grad(|| { - /// marian_model.forward_t( - /// Some(&input_tensor), - /// Some(&encoder_attention_mask), - /// None, - /// Some(&target_tensor), - /// Some(&decoder_attention_mask), - /// None, - /// false, - /// ) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - cached_layer_states, - train, - ), - Cache::None => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - None, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with Marian Model".into(), - )); - } - }; - - let lm_logits = base_model_output - .decoder_output - .linear::(&self.base_model.embeddings.ws, None) - + &self.final_logits_bias; - Ok(LMModelOutput { - lm_logits, - cache: Cache::BARTCache(base_model_output.cache), - }) - } -} - /// # Language generation model based on the Marian architecture for machine translation pub struct MarianGenerator { model: MarianForConditionalGeneration, @@ -914,12 +804,7 @@ impl MarianGenerator { } } -impl PrivateLanguageGenerator - for MarianGenerator -{ - fn get_model(&self) -> &MarianForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for MarianGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -954,6 +839,50 @@ impl PrivateLanguageGenerator, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::BARTCache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + cached_layer_states, + train, + ), + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with Marian Model".into(), + )); + } + }; + + Ok(LMModelOutput { + lm_logits: base_model_output.decoder_output, + cache: Cache::BARTCache(base_model_output.cache), + }) + } + fn prepare_scores_for_generation( &self, scores: &mut Tensor, @@ -976,7 +905,7 @@ impl PrivateLanguageGenerator) -> Option { - Some(self.get_model().encode(input_ids, attention_mask)) + Some(self.model.encode(input_ids, attention_mask)) } fn prepare_inputs_for_generation<'a>( @@ -1086,7 +1015,4 @@ impl PrivateLanguageGenerator - for MarianGenerator -{ -} +impl LanguageGenerator for MarianGenerator {} diff --git a/src/marian/mod.rs b/src/marian/mod.rs index 28d76d6c0..a302c954c 100644 --- a/src/marian/mod.rs +++ b/src/marian/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the Marian language model ([Marian: Fast Neural Machine Translation in {C++}](http://www.aclweb.org/anthology/P18-4020) Junczys-Dowmunt, Grundkiewicz, Dwojak, Hoang, Heafield, Neckermann, Seide, Germann, Fikri Aji, Bogoychev, Martins, Birch, 2018). //! The base model is implemented in the `bart_model::BartModel` struct. This model includes a language model head: `marian_model::MarianForConditionalGeneration` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index 471f562fc..7ddbd1899 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -19,13 +19,10 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::pipelines::translation::Language; use crate::{Activation, Config, RustBertError}; -use rust_tokenizers::tokenizer::{MBart50Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::MBart50Vocab; +use rust_tokenizers::tokenizer::TruncationStrategy; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::collections::HashMap; @@ -717,111 +714,6 @@ impl MBartForSequenceClassification { } } -impl LMHeadModel for MBartForConditionalGeneration { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Unused for MBart - /// * `token_type_ids` - Unused for MBart - /// * `position_ids` - Unused for MBart - /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `BartCache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for - /// both the self attention and the encoder cross attention of each layer of the decoder. - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::pipelines::generation_utils::LMHeadModel; - /// use rust_bert::mbart::{MBartForConditionalGeneration, MBartConfig}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = MBartConfig::from_file(config_path); - /// # let mbart_model: MBartForConditionalGeneration = MBartForConditionalGeneration::new(&vs.root(), &config); - /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); - /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); - /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// - /// let model_output = no_grad(|| { - /// mbart_model - /// .forward_t(Some(&input_tensor), - /// Some(&encoder_attention_mask), - /// None, - /// Some(&target_tensor), - /// Some(&decoder_attention_mask), - /// None, - /// false) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - cached_layer_states, - train, - ), - - Cache::None => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids, - encoder_outputs, - None, - None, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with MBART Model".into(), - )); - } - }; - - let lm_logits = base_model_output - .decoder_output - .linear::(&self.base_model.embeddings.ws, None) - + &self.final_logits_bias; - Ok(LMModelOutput { - lm_logits, - cache: Cache::BARTCache(base_model_output.cache), - }) - } -} - /// Container holding a MBART model output pub type MBartModelOutput = BartModelOutput; @@ -944,12 +836,7 @@ impl MBartGenerator { } } -impl PrivateLanguageGenerator - for MBartGenerator -{ - fn get_model(&self) -> &MBartForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for MBartGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -981,6 +868,51 @@ impl PrivateLanguageGenerator, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::BARTCache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + cached_layer_states, + train, + ), + + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with MBART Model".into(), + )); + } + }; + + Ok(LMModelOutput { + lm_logits: base_model_output.decoder_output, + cache: Cache::BARTCache(base_model_output.cache), + }) + } + fn get_max_positions_embeddings(&self) -> i64 { self.max_position_embeddings } @@ -1002,7 +934,7 @@ impl PrivateLanguageGenerator) -> Option { - Some(self.get_model().encode(input_ids, attention_mask)) + Some(self.model.encode(input_ids, attention_mask)) } fn prepare_inputs_for_generation<'a>( @@ -1111,10 +1043,7 @@ impl PrivateLanguageGenerator - for MBartGenerator -{ -} +impl LanguageGenerator for MBartGenerator {} #[cfg(test)] mod test { diff --git a/src/mbart/mod.rs b/src/mbart/mod.rs index 72f6fb712..93ebcff97 100644 --- a/src/mbart/mod.rs +++ b/src/mbart/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the MBart language model ([Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) Liu, Gu, Goyal, Li, Edunov, Ghazvininejad, Lewis, Zettlemoyer, 2020). //! The base model is implemented in the `mbart_model::MBartModel` struct. The model also includes a language model head: `mbart_model::MBartForConditionalGeneration` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/openai_gpt/mod.rs b/src/openai_gpt/mod.rs index ab74d0630..e97a16c09 100644 --- a/src/openai_gpt/mod.rs +++ b/src/openai_gpt/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the GPT2 language model ([Improving Language Understanding by Generative Pre-Training](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf) Radford, Narasimhan, Salimans, Sutskever 2018). //! The base model is implemented in the `openai_gpt_model::OpenAiGptModel` struct. The model also includes a language model head: `openai_gpt_model::OpenAIGPTLMHeadModel` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs index 52adde72d..7cdc15d87 100644 --- a/src/openai_gpt/openai_gpt_model.rs +++ b/src/openai_gpt/openai_gpt_model.rs @@ -19,12 +19,8 @@ use crate::gpt2::Gpt2Config; use crate::openai_gpt::transformer::Block; use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::OpenAiGptTokenizer; -use rust_tokenizers::vocab::OpenAiGptVocab; use std::borrow::{Borrow, BorrowMut}; use tch::kind::Kind::Int64; use tch::nn::embedding; @@ -326,9 +322,7 @@ impl OpenAIGPTLMHeadModel { lm_head, } } -} -impl LMHeadModel for OpenAIGPTLMHeadModel { /// Forward pass through the model /// /// # Arguments @@ -362,7 +356,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel { /// # use tch::kind::Kind::{Int64, Double}; /// use rust_bert::gpt2::Gpt2Config; /// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel; - /// use rust_bert::pipelines::generation_utils::{LMHeadModel, Cache}; + /// use rust_bert::pipelines::generation_utils::Cache; /// # let config_path = Path::new("path/to/config.json"); /// # let vocab_path = Path::new("path/to/vocab.txt"); /// # let device = Device::Cpu; @@ -388,7 +382,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel { /// false).unwrap() /// }); /// ``` - fn forward_t( + pub fn forward_t( &self, input_ids: Option<&Tensor>, _layer_past: Cache, @@ -531,12 +525,7 @@ impl OpenAIGenerator { } } -impl PrivateLanguageGenerator - for OpenAIGenerator -{ - fn get_model(&self) -> &OpenAIGPTLMHeadModel { - &self.model - } +impl PrivateLanguageGenerator for OpenAIGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -570,9 +559,31 @@ impl PrivateLanguageGenerator i64 { self.max_position_embeddings } -} -impl LanguageGenerator - for OpenAIGenerator -{ + fn forward_t( + &self, + input_ids: Option<&Tensor>, + _layer_past: Cache, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + _encoder_outputs: Option<&Tensor>, + _decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + self.model.forward_t( + input_ids, + _layer_past, + attention_mask, + token_type_ids, + position_ids, + input_embeds, + _encoder_outputs, + _decoder_input_ids, + train, + ) + } } + +impl LanguageGenerator for OpenAIGenerator {} diff --git a/src/pegasus/mod.rs b/src/pegasus/mod.rs index cd7b4b33d..678768911 100644 --- a/src/pegasus/mod.rs +++ b/src/pegasus/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the Pegasus language model ([PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) Zhang, Zhao, Saleh, Liu, 2019). //! The base model is implemented in the `pegasus_model::PegasusModel` struct and leverages an implementation that is broadly similar to BART. The model also includes a language model head: `pegasus_model::PegasusForConditionalGeneration` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs index 24569f532..1175ec2ae 100644 --- a/src/pegasus/pegasus_model.rs +++ b/src/pegasus/pegasus_model.rs @@ -20,12 +20,9 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::{PegasusTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::PegasusVocab; +use rust_tokenizers::tokenizer::TruncationStrategy; use std::borrow::Borrow; use tch::nn::{embedding, EmbeddingConfig, Init}; use tch::{nn, Tensor}; @@ -430,120 +427,6 @@ impl PegasusForConditionalGeneration { } } -impl LMHeadModel for PegasusForConditionalGeneration { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Unused for Pegasus - /// * `token_type_ids` - Unused for Pegasus - /// * `position_ids` - Unused for Pegasus - /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). Must be provided when running in generation mode (e.g. initialized with a BOS token) - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `BartCache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for - /// both the self attention and the encoder cross attention of each layer of the decoder. - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::pipelines::generation_utils::LMHeadModel; - /// use rust_bert::pegasus::{PegasusForConditionalGeneration, PegasusConfig}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = PegasusConfig::from_file(config_path); - /// # let pegasus_model: PegasusForConditionalGeneration = PegasusForConditionalGeneration::new(&vs.root(), &config); - /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); - /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); - /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// - /// let model_output = no_grad(|| { - /// pegasus_model - /// .forward_t(Some(&input_tensor), - /// Some(&encoder_attention_mask), - /// None, - /// Some(&target_tensor), - /// Some(&decoder_attention_mask), - /// None, - /// false) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids.ok_or_else(|| { - RustBertError::ValueError( - "Decoder input ids must be provided for Pegasus language models" - .to_string(), - ) - })?, - encoder_outputs, - None, - cached_layer_states, - train, - ), - Cache::None => self.base_model.forward_t( - input_ids, - attention_mask, - decoder_input_ids.ok_or_else(|| { - RustBertError::ValueError( - "Decoder input ids must be provided for Pegasus language models" - .to_string(), - ) - })?, - encoder_outputs, - None, - None, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with Pegasus Model".into(), - )); - } - }; - - let lm_logits = base_model_output - .decoder_output - .linear::(&self.base_model.embeddings.ws, None) - + &self.final_logits_bias; - Ok(LMModelOutput { - lm_logits, - cache: Cache::BARTCache(base_model_output.cache), - }) - } -} - /// # Language generation model based on the Pegasus architecture pub struct PegasusConditionalGenerator { model: PegasusForConditionalGeneration, @@ -666,12 +549,7 @@ impl PegasusConditionalGenerator { } } -impl PrivateLanguageGenerator - for PegasusConditionalGenerator -{ - fn get_model(&self) -> &PegasusForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for PegasusConditionalGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -706,6 +584,50 @@ impl PrivateLanguageGenerator, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::BARTCache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + cached_layer_states, + train, + ), + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with Pegasus Model".into(), + )); + } + }; + + Ok(LMModelOutput { + lm_logits: base_model_output.decoder_output, + cache: Cache::BARTCache(base_model_output.cache), + }) + } + fn prepare_scores_for_generation( &self, scores: &mut Tensor, @@ -721,7 +643,7 @@ impl PrivateLanguageGenerator) -> Option { - Some(self.get_model().encode(input_ids, attention_mask)) + Some(self.model.encode(input_ids, attention_mask)) } fn prepare_inputs_for_generation<'a>( @@ -833,10 +755,7 @@ impl PrivateLanguageGenerator - for PegasusConditionalGenerator -{ -} +impl LanguageGenerator for PegasusConditionalGenerator {} /// Container holding a Pegasus model output. The decoder output may hold the hidden state of /// the last layer of the decoder, or may hold logits for a custom head module after the diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index 71e68c7ff..8c4a601dd 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -66,13 +66,10 @@ //! # ; //! ``` -use rust_tokenizers::tokenizer::Tokenizer; -use rust_tokenizers::vocab::Vocab; use tch::kind::Kind::Int64; use tch::{no_grad, Device, Tensor}; use crate::bart::LayerState as BartLayerState; -use crate::common::error::RustBertError; use crate::common::resources::ResourceProvider; use crate::gpt_j::LayerState as GPTJLayerState; use crate::gpt_neo::LayerState as GPTNeoLayerState; @@ -234,18 +231,18 @@ pub(crate) mod private_generation_utils { use std::collections::HashMap; use std::mem; - use rust_tokenizers::tokenizer::{truncate_sequences, Tokenizer, TruncationStrategy}; - use rust_tokenizers::vocab::Vocab; + use rust_tokenizers::tokenizer::{truncate_sequences, TruncationStrategy}; use rust_tokenizers::TokenIdsWithOffsets; use tch::{nn, Device, Kind, Tensor}; use crate::pipelines::common::TokenizerOption; use crate::pipelines::generation_utils::{ - BeamHypotheses, Cache, GenerateConfig, LMHeadModel, PrefixAllowedFunction, + BeamHypotheses, Cache, GenerateConfig, LMModelOutput, PrefixAllowedFunction, }; use super::ordered_float::OrderedFloat; use crate::common::kind::get_positive_infinity; + use crate::RustBertError; pub struct InternalGenerateOptions<'a> { pub min_length: i64, @@ -283,8 +280,7 @@ pub(crate) mod private_generation_utils { pub token_scores: Option>>, } - pub trait PrivateLanguageGenerator> { - fn get_model(&self) -> &T; + pub trait PrivateLanguageGenerator { fn _get_tokenizer(&self) -> &TokenizerOption; fn get_var_store(&self) -> &nn::VarStore; fn get_var_store_mut(&mut self) -> &mut nn::VarStore; @@ -297,6 +293,19 @@ pub(crate) mod private_generation_utils { fn get_decoder_start_id(&self) -> Option; fn get_max_positions_embeddings(&self) -> i64; + fn forward_t( + &self, + input_ids: Option<&Tensor>, + layer_past: Cache, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result; + fn prepare_scores_for_generation( &self, _scores: &mut Tensor, @@ -778,7 +787,6 @@ pub(crate) mod private_generation_utils { attention_mask.copy(), ); let temp = self - .get_model() .forward_t( prepared_input.prepared_input.as_ref(), prepared_input.prepared_past, @@ -1054,7 +1062,6 @@ pub(crate) mod private_generation_utils { attention_mask.copy(), ); let temp = self - .get_model() .forward_t( prepared_input.prepared_input.as_ref(), prepared_input.prepared_past, @@ -1590,9 +1597,7 @@ macro_rules! unpack_config { /// # Common trait for text generation models. /// Main API for text generation -pub trait LanguageGenerator>: - PrivateLanguageGenerator -{ +pub trait LanguageGenerator: PrivateLanguageGenerator { /// Generate text based on a vector of promp texts. /// /// # Arguments @@ -2255,93 +2260,6 @@ impl BeamHypotheses { } } -/// # Language Model trait -/// Shared trait between language generation models (e.g. GPT2, GPT, BART) used in language generation pipelines. -pub trait LMHeadModel { - /// Forward pass through the model. Example provided for GPT2. - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of size *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`) - /// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings. - /// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input. - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// # Returns - /// - /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// * `past` - `Option>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*) - /// * `hidden_states` - `Option>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*) - /// * `attentions` - `Option>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*) - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config}; - /// use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = Gpt2Config::from_file(config_path); - /// # let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config); - /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); - /// let mut past: Vec = Vec::with_capacity(config.n_layer as usize); - /// for _ in 0..config.n_layer as usize { - /// past.push(Tensor::rand( - /// &[ - /// 2, - /// batch_size, - /// config.n_head, - /// past_sequence_length, - /// config.n_embd / config.n_head, - /// ], - /// (Double, device), - /// )) - /// } - /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); - /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) - /// .expand(&[batch_size, sequence_length], true); - /// - /// let model_output = no_grad(|| { - /// gpt2_model - /// .forward_t( - /// Some(&input_tensor), - /// Cache::GPT2Cache(Some(past)), - /// Some(&attention_mask), - /// Some(&token_type_ids), - /// Some(&position_ids), - /// None, - /// None, - /// None, - /// false, - /// ) - /// .unwrap() - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - layer_past: Cache, - attention_mask: Option<&Tensor>, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor>, - input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result; -} - /// Container holding a language model output for generation tasks pub struct LMModelOutput { /// Logits for each vocab item and position diff --git a/src/prophetnet/mod.rs b/src/prophetnet/mod.rs index c46e5791e..b6a069cd9 100644 --- a/src/prophetnet/mod.rs +++ b/src/prophetnet/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the ProphetNet language model ([ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) Qi, Yan, Gong, Liu, Duan, Chen, Zhang, Zhou, 2020). //! The base model is implemented in the `prophetnet_model::ProphetNetModel` struct. Two language model heads have also been implemented: -//! - Conditional language generation (encoder-decoder architecture): `prophetnet_model::ProphetNetForConditionalGeneration` implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information) +//! - Conditional language generation (encoder-decoder architecture): `prophetnet_model::ProphetNetForConditionalGeneration` implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information) //! - Causal language generation (decoder architecture): `prophetnet_model::ProphetNetForCausalGeneration` //! //! # Model set-up and pre-trained weights loading diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs index 8e7e872cd..b91426475 100644 --- a/src/prophetnet/prophetnet_model.rs +++ b/src/prophetnet/prophetnet_model.rs @@ -13,8 +13,7 @@ use std::borrow::Borrow; use std::collections::HashMap; -use rust_tokenizers::tokenizer::{ProphetNetTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::ProphetNetVocab; +use rust_tokenizers::tokenizer::TruncationStrategy; use serde::{Deserialize, Serialize}; use tch::{nn, Kind, Tensor}; @@ -22,9 +21,7 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::prophetnet::attention::LayerState; use crate::prophetnet::decoder::ProphetNetDecoder; use crate::prophetnet::encoder::ProphetNetEncoder; @@ -585,56 +582,6 @@ impl ProphetNetForConditionalGeneration { } } -impl LMHeadModel for ProphetNetForConditionalGeneration { - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::ProphetNetCache(cached_layer_states) => self.forward_t( - input_ids, - attention_mask, - input_embeds, - decoder_input_ids, - None, - encoder_outputs, - cached_layer_states, - None, - train, - )?, - Cache::None => self.forward_t( - input_ids, - attention_mask, - input_embeds, - decoder_input_ids, - None, - encoder_outputs, - None, - None, - train, - )?, - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with ProphetNet Model".into(), - )); - } - }; - - Ok(LMModelOutput { - lm_logits: base_model_output.logits, - cache: Cache::ProphetNetCache(base_model_output.next_decoder_cache), - }) - } -} - /// # ProphetNet Model for causal generation /// ProphetNet decoder with a vocabulary decoding head /// It is made of the following blocks: @@ -992,16 +939,7 @@ impl ProphetNetConditionalGenerator { } } -impl - PrivateLanguageGenerator< - ProphetNetForConditionalGeneration, - ProphetNetVocab, - ProphetNetTokenizer, - > for ProphetNetConditionalGenerator -{ - fn get_model(&self) -> &ProphetNetForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for ProphetNetConditionalGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -1036,9 +974,57 @@ impl self.max_position_embeddings } + fn forward_t( + &self, + input_ids: Option<&Tensor>, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::ProphetNetCache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + input_embeds, + decoder_input_ids, + None, + encoder_outputs, + cached_layer_states, + None, + train, + )?, + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + input_embeds, + decoder_input_ids, + None, + encoder_outputs, + None, + None, + train, + )?, + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with ProphetNet Model".into(), + )); + } + }; + + Ok(LMModelOutput { + lm_logits: base_model_output.logits, + cache: Cache::ProphetNetCache(base_model_output.next_decoder_cache), + }) + } + fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option { Some( - self.get_model() + self.model .encode(Some(input_ids), attention_mask, None) .unwrap(), ) @@ -1150,7 +1136,4 @@ impl } } -impl LanguageGenerator - for ProphetNetConditionalGenerator -{ -} +impl LanguageGenerator for ProphetNetConditionalGenerator {} diff --git a/src/reformer/mod.rs b/src/reformer/mod.rs index 5887e683d..5a076771f 100644 --- a/src/reformer/mod.rs +++ b/src/reformer/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the Reformer language model ([Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) Kitaev, kaiser, Levskaya, 2020). //! The base model is implemented in the `reformer_model::ReformerModel` struct. The model also includes a language model head: `reformer_model::ReformerModelWithLMHead` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs index 61afbf2a1..de9e02d89 100644 --- a/src/reformer/reformer_model.rs +++ b/src/reformer/reformer_model.rs @@ -14,8 +14,6 @@ use std::borrow::Borrow; use std::collections::HashMap; -use rust_tokenizers::tokenizer::ReformerTokenizer; -use rust_tokenizers::vocab::ReformerVocab; use serde::{Deserialize, Serialize}; use serde_json::Value; use tch::{nn, Device, Kind, Tensor}; @@ -27,9 +25,7 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::reformer::attention::{AttentionType, LayerState}; use crate::reformer::attention_utils::{get_least_common_mult_chunk_len, get_min_chunk_len}; use crate::reformer::embeddings::ReformerEmbeddings; @@ -649,44 +645,6 @@ impl ReformerModelWithLMHead { } } -impl LMHeadModel for ReformerModelWithLMHead { - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - _encoder_outputs: Option<&Tensor>, - _decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let output = match cache { - Cache::ReformerCache(cached_layer_states) => self.forward_t( - input_ids, - None, - None, - attention_mask, - None, - cached_layer_states, - train, - ), - Cache::None => self.forward_t(input_ids, None, None, attention_mask, None, None, train), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with Reformer Model".into(), - )); - } - }?; - - Ok(LMModelOutput { - lm_logits: output.logits, - cache: Cache::ReformerCache(output.next_cache), - }) - } -} - pub struct ReformerClassificationHead { dense: nn::Linear, dropout: Dropout, @@ -1119,12 +1077,7 @@ impl ReformerGenerator { } } -impl PrivateLanguageGenerator - for ReformerGenerator -{ - fn get_model(&self) -> &ReformerModelWithLMHead { - &self.model - } +impl PrivateLanguageGenerator for ReformerGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -1159,6 +1112,45 @@ impl PrivateLanguageGenerator, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + _encoder_outputs: Option<&Tensor>, + _decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let output = match cache { + Cache::ReformerCache(cached_layer_states) => self.model.forward_t( + input_ids, + None, + None, + attention_mask, + None, + cached_layer_states, + train, + ), + Cache::None => { + self.model + .forward_t(input_ids, None, None, attention_mask, None, None, train) + } + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with Reformer Model".into(), + )); + } + }?; + + Ok(LMModelOutput { + lm_logits: output.logits, + cache: Cache::ReformerCache(output.next_cache), + }) + } + fn prepare_inputs_for_generation<'a>( &self, input_ids: Tensor, @@ -1213,7 +1205,4 @@ impl PrivateLanguageGenerator - for ReformerGenerator -{ -} +impl LanguageGenerator for ReformerGenerator {} diff --git a/src/t5/mod.rs b/src/t5/mod.rs index b13a670e7..ec9004c8d 100644 --- a/src/t5/mod.rs +++ b/src/t5/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the T5 language model ([Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) Raffel, Shazeer, Roberts, Lee, Narang, Matena, Zhou, Li, Liu, 2019). //! The base model is implemented in the `t5_model::T5Model` struct. This model includes a language model head: `t5_model::T5ForConditionalGeneration` -//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information). //! //! # Model set-up and pre-trained weights loading //! diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 190a925e8..3d975dd3d 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -12,8 +12,7 @@ use std::borrow::Borrow; -use rust_tokenizers::tokenizer::{T5Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::T5Vocab; +use rust_tokenizers::tokenizer::TruncationStrategy; use serde::{Deserialize, Serialize}; use tch::nn::{embedding, LinearConfig}; use tch::{nn, Tensor}; @@ -22,9 +21,7 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::pipelines::translation::Language; use crate::t5::attention::LayerState; use crate::t5::encoder::T5Stack; @@ -622,124 +619,6 @@ impl T5ForConditionalGeneration { } } -impl LMHeadModel for T5ForConditionalGeneration { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) - /// * `layer_past` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. - /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 - /// * `input_embeds` - Unused for T5 - /// * `token_type_ids` - Unused for T5 - /// * `position_ids` - Unused for T5 - /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. - /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `T5Cache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for - /// both the self attention and the encoder cross attention of each layer of the decoder. - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::t5::{T5Config, T5ForConditionalGeneration}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = T5Config::from_file(config_path); - /// # let t5_model: T5ForConditionalGeneration = T5ForConditionalGeneration::new(&vs.root(), &config); - /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); - /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); - /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); - /// let encoder_attention_mask = - /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// let decoder_attention_mask = - /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); - /// - /// let model_output = no_grad(|| { - /// t5_model.forward_t( - /// Some(&input_tensor), - /// Some(&encoder_attention_mask), - /// None, - /// Some(&target_tensor), - /// Some(&decoder_attention_mask), - /// None, - /// None, - /// None, - /// false, - /// ) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - cache: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - let base_model_output = match cache { - Cache::T5Cache(cached_layer_states) => self.base_model.forward_t( - input_ids, - attention_mask, - encoder_outputs, - decoder_input_ids, - None, - None, - None, - cached_layer_states, - train, - ), - Cache::None => self.base_model.forward_t( - input_ids, - attention_mask, - encoder_outputs, - decoder_input_ids, - None, - None, - None, - None, - train, - ), - _ => { - return Err(RustBertError::ValueError( - "Cache not compatible with T5 Model".into(), - )); - } - }; - - let lm_logits = if self.tie_word_embeddings { - base_model_output - .decoder_output - .linear::(&self.base_model.embeddings.ws, None) - * (self.model_dim.powf(-0.5)) - } else { - base_model_output - .decoder_output - .apply(self.lm_head.as_ref().unwrap()) - }; - - Ok(LMModelOutput { - lm_logits, - cache: Cache::T5Cache(base_model_output.next_cache), - }) - } -} - /// # T5 for sentence embeddings /// Transformer usable in [`SentenceEmbeddingsModel`](crate::pipelines::sentence_embeddings::SentenceEmbeddingsModel). pub struct T5ForSentenceEmbeddings { @@ -912,10 +791,7 @@ impl T5Generator { } } -impl PrivateLanguageGenerator for T5Generator { - fn get_model(&self) -> &T5ForConditionalGeneration { - &self.model - } +impl PrivateLanguageGenerator for T5Generator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -949,9 +825,55 @@ impl PrivateLanguageGenerator fn get_max_positions_embeddings(&self) -> i64 { self.max_position_embeddings } + fn forward_t( + &self, + input_ids: Option<&Tensor>, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::T5Cache(cached_layer_states) => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + None, + cached_layer_states, + train, + ), + Cache::None => self.model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + None, + None, + train, + ), + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with T5 Model".into(), + )); + } + }; + Ok(LMModelOutput { + lm_logits: base_model_output.decoder_output, + cache: Cache::T5Cache(base_model_output.next_cache), + }) + } fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option { - Some(self.get_model().encode(input_ids, attention_mask)) + Some(self.model.encode(input_ids, attention_mask)) } fn prepare_inputs_for_generation<'a>( @@ -1059,4 +981,4 @@ impl PrivateLanguageGenerator } } -impl LanguageGenerator for T5Generator {} +impl LanguageGenerator for T5Generator {} diff --git a/src/xlnet/mod.rs b/src/xlnet/mod.rs index aa55d6dfe..1541838fe 100644 --- a/src/xlnet/mod.rs +++ b/src/xlnet/mod.rs @@ -2,7 +2,7 @@ //! //! Implementation of the XLNet language model ([Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) Yang, Dai, Yang, Carbonell, Salakhutdinov, Le, 2019). //! The base model is implemented in the `xlnet_model::XLNetModel` struct. Several language model heads have also been implemented, including: -//! - Language generation: `xlnet_model::XLNetLMHeadModel` implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information) +//! - Language generation: `xlnet_model::XLNetLMHeadModel` implementing the common `generation_utils::LanguageGenerator` trait shared between the models used for generation (see `pipelines` for more information) //! - Multiple choices: `xlnet_model:XLNetForMultipleChoice` //! - Question answering: `xlnet_model::XLNetForQuestionAnswering` //! - Sequence classification: `xlnet_model::XLNetForSequenceClassification` diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs index c62a5d2e0..fe5f386f1 100644 --- a/src/xlnet/xlnet_model.rs +++ b/src/xlnet/xlnet_model.rs @@ -19,14 +19,10 @@ use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, }; -use crate::pipelines::generation_utils::{ - Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, -}; +use crate::pipelines::generation_utils::{Cache, GenerateConfig, LMModelOutput, LanguageGenerator}; use crate::xlnet::attention::LayerState; use crate::xlnet::encoder::XLNetLayer; use crate::{Config, RustBertError}; -use rust_tokenizers::tokenizer::XLNetTokenizer; -use rust_tokenizers::vocab::XLNetVocab; use serde::{Deserialize, Serialize}; use std::borrow::{Borrow, BorrowMut}; use std::collections::HashMap; @@ -791,102 +787,6 @@ impl XLNetLMHeadModel { } } -impl LMHeadModel for XLNetLMHeadModel { - /// Forward pass through the model - /// - /// # Arguments - /// - /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided. - /// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. - /// * `perm_mask` - Optional tensor of shape (*batch size*, *sequence_length*, *sequence_length*). Mask to indicate the attention pattern for each input token (only used for pre-training over permutations, rather than simple token masking). - /// * `target_mapping ` - Optional tensor of shape (*batch size*, *num_tokens*, *sequence_length*) indicating the position of the masked words to predict. - /// * `token_type_ids` - Optional tensor (*batch size*, *sequence_length*) indicating the sentence ID of the token (0: first sentence, 1: second sentence). - /// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided. - /// * `old_layer_states` - Optional vector of length `num_layers` containing optional `LayerStates` containing the last calculated content for the attention layers. This avoids recomputing attention weights at past positions and speeds up decoding. - /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. - /// - /// # Returns - /// - /// * `LMModelOutput` containing: - /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position - /// - `cache` - `XLNetCache` made of `Option>>` of length *n_layers* and shape (*past_sequence_length*, *batch size*, *hidden_size*) containing the previous content - /// - /// # Example - /// - /// ```no_run - /// # use tch::{nn, Device, Tensor, no_grad, Kind}; - /// # use rust_bert::Config; - /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Double}; - /// use rust_bert::xlnet::{XLNetConfig, XLNetLMHeadModel}; - /// # let config_path = Path::new("path/to/config.json"); - /// # let vocab_path = Path::new("path/to/vocab.txt"); - /// # let device = Device::Cpu; - /// # let vs = nn::VarStore::new(device); - /// # let config = XLNetConfig::from_file(config_path); - /// # let xlnet_model: XLNetLMHeadModel = XLNetLMHeadModel::new(&vs.root(), &config); - /// let (batch_size, sequence_length) = (64, 128); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); - /// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); - /// let target_tensor = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); - /// let target_mapping = Tensor::zeros(&[64, 1, 128], (Kind::Float, device)); - /// let _ = target_mapping.narrow(2, 3, 1).fill_(1.0); - /// - /// let model_output = no_grad(|| { - /// xlnet_model.forward_t( - /// Some(&input_tensor), - /// Some(&attention_mask), - /// None, - /// Some(&target_mapping), - /// None, - /// None, - /// None, - /// false, - /// ) - /// }); - /// ``` - fn forward_t( - &self, - input_ids: Option<&Tensor>, - layer_past: Cache, - attention_mask: Option<&Tensor>, - _token_type_ids: Option<&Tensor>, - _position_ids: Option<&Tensor>, - _input_embeds: Option<&Tensor>, - _encoder_outputs: Option<&Tensor>, - decoder_input_ids: Option<&Tensor>, - train: bool, - ) -> Result { - match layer_past { - Cache::XLNetCache(layer_past) => self.forward_t( - input_ids, - None, - layer_past, - attention_mask, - // For XLNet the decoder_input_ids are used as a placeholder for the target mapping - decoder_input_ids, - None, - None, - train, - ), - Cache::None => self.forward_t( - input_ids, - None, - None, - attention_mask, - // For XLNet the decoder_input_ids are used as a placeholder for the target mapping - decoder_input_ids, - None, - None, - train, - ), - _ => Err(RustBertError::ValueError( - "Cache not compatible with XLNet Model".into(), - )), - } - } -} - /// # XLNetForSequenceClassification /// XLNet model with a classification head for sequence classification tasks /// It is made of the following blocks: @@ -1684,10 +1584,7 @@ impl XLNetGenerator { } } -impl PrivateLanguageGenerator for XLNetGenerator { - fn get_model(&self) -> &XLNetLMHeadModel { - &self.model - } +impl PrivateLanguageGenerator for XLNetGenerator { fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } @@ -1723,6 +1620,47 @@ impl PrivateLanguageGenerator for self.max_position_embeddings } + fn forward_t( + &self, + input_ids: Option<&Tensor>, + layer_past: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + _encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + match layer_past { + Cache::XLNetCache(layer_past) => self.model.forward_t( + input_ids, + None, + layer_past, + attention_mask, + // For XLNet the decoder_input_ids are used as a placeholder for the target mapping + decoder_input_ids, + None, + None, + train, + ), + Cache::None => self.model.forward_t( + input_ids, + None, + None, + attention_mask, + // For XLNet the decoder_input_ids are used as a placeholder for the target mapping + decoder_input_ids, + None, + None, + train, + ), + _ => Err(RustBertError::ValueError( + "Cache not compatible with XLNet Model".into(), + )), + } + } + fn prepare_inputs_for_generation<'a>( &self, input_ids: Tensor, @@ -1842,4 +1780,4 @@ impl PrivateLanguageGenerator for } } -impl LanguageGenerator for XLNetGenerator {} +impl LanguageGenerator for XLNetGenerator {} diff --git a/tests/distilgpt2.rs b/tests/distilgpt2.rs index 1ca7096bb..2727d942a 100644 --- a/tests/distilgpt2.rs +++ b/tests/distilgpt2.rs @@ -2,7 +2,7 @@ use rust_bert::gpt2::{ GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources, }; -use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; +use rust_bert::pipelines::generation_utils::Cache; use rust_bert::resources::{RemoteResource, ResourceProvider}; use rust_bert::Config; use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy}; @@ -61,17 +61,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> { // Forward pass let model_output = gpt2_model - .forward_t( - Some(&input_tensor), - Cache::None, - None, - None, - None, - None, - None, - None, - false, - ) + .forward_t(Some(&input_tensor), None, None, None, None, None, false) .unwrap(); let next_word_id = model_output diff --git a/tests/gpt2.rs b/tests/gpt2.rs index 143ccd2fa..b820348f3 100644 --- a/tests/gpt2.rs +++ b/tests/gpt2.rs @@ -7,7 +7,7 @@ use rust_bert::pipelines::conversation::{ ConversationConfig, ConversationManager, ConversationModel, }; use rust_bert::pipelines::generation_utils::{ - Cache, GenerateConfig, GenerateOptions, LMHeadModel, LanguageGenerator, + Cache, GenerateConfig, GenerateOptions, LanguageGenerator, }; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::resources::{RemoteResource, ResourceProvider}; @@ -59,17 +59,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> { // Forward pass let model_output = gpt2_model - .forward_t( - Some(&input_tensor), - Cache::None, - None, - None, - None, - None, - None, - None, - false, - ) + .forward_t(Some(&input_tensor), None, None, None, None, None, false) .unwrap(); let next_word_id = model_output diff --git a/tests/gpt_j.rs b/tests/gpt_j.rs index a39ebbcaa..228ad9d69 100644 --- a/tests/gpt_j.rs +++ b/tests/gpt_j.rs @@ -2,7 +2,7 @@ use rust_bert::gpt_j::{ GptJConfig, GptJConfigResources, GptJLMHeadModel, GptJMergesResources, GptJModelResources, GptJVocabResources, }; -use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; +use rust_bert::pipelines::generation_utils::Cache; use rust_bert::resources::{RemoteResource, ResourceProvider}; use rust_bert::Config; use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer}; diff --git a/tests/openai_gpt.rs b/tests/openai_gpt.rs index 0e12f5d53..f62d1d74f 100644 --- a/tests/openai_gpt.rs +++ b/tests/openai_gpt.rs @@ -3,7 +3,7 @@ use rust_bert::openai_gpt::{ OpenAiGptModelResources, OpenAiGptVocabResources, }; use rust_bert::pipelines::common::ModelType; -use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; +use rust_bert::pipelines::generation_utils::Cache; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; use rust_bert::resources::{RemoteResource, ResourceProvider}; use rust_bert::Config;