From d7e9c036940aefebe81c083841e4ff3bad36bd56 Mon Sep 17 00:00:00 2001 From: guillaume-be Date: Sun, 12 Feb 2023 16:18:20 +0000 Subject: [PATCH] Long t5 implementation (#333) * LongT5 config implementation * LongT5 WiP: utility functions 1 * LongT5 WiP: utility functions (2) * LongT5 WiP: utility functions (3) * LongT5 WiP: utility functions (4) * made T5 FF activations generic, expose T5 modules to crate * Longt% local attention WIP * LongT5 local attention * LongT5 global attention WIP * LongT5 global attention * LongT5 attention modules (WIP) * align LongT5 position bias with T5 * Addition of LongT5Block * LongT5Stack WiP * LongT5Stack implementation * LongT5Model implementation * LongT5ForConditionalGeneration implementation * Addition of LongT5Generator, inclusion in pipelines * LongT5 attention fixes * Fix MIN/MAX dtype computation, mask for longt5 * Updated min/max and infinity computation across models * GlobalTransient attention fixes * Updated changelog, readme, tests, clippy --- .github/workflows/continuous-integration.yml | 1 + CHANGELOG.md | 6 + README.md | 1 + src/bart/bart_model.rs | 9 +- src/common/kind.rs | 19 + src/deberta/deberta_model.rs | 4 +- src/lib.rs | 4 + src/longt5/attention.rs | 807 +++++++++++++++++ src/longt5/encoder.rs | 457 ++++++++++ src/longt5/layer_norm.rs | 15 + src/longt5/longt5_model.rs | 894 +++++++++++++++++++ src/longt5/mod.rs | 59 ++ src/pipelines/common.rs | 10 +- src/pipelines/generation_utils.rs | 1 + src/pipelines/summarization.rs | 12 + src/prophetnet/decoder.rs | 6 +- src/t5/attention.rs | 67 +- src/t5/encoder.rs | 59 +- src/t5/mod.rs | 4 + src/t5/t5_model.rs | 4 +- tests/deberta_v2.rs | 2 +- tests/longt5.rs | 64 ++ utils/convert_model.py | 15 +- 23 files changed, 2444 insertions(+), 76 deletions(-) create mode 100644 src/longt5/attention.rs create mode 100644 src/longt5/encoder.rs create mode 100644 src/longt5/layer_norm.rs create mode 100644 src/longt5/longt5_model.rs create mode 100644 src/longt5/mod.rs create mode 100644 tests/longt5.rs diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 18d94b072..e9ba0e2b1 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -130,6 +130,7 @@ jobs: command: test args: --package rust-bert --test sentence_embeddings + --test longt5 convert-model: name: Model conversion test diff --git a/CHANGELOG.md b/CHANGELOG.md index 00b0c8a5b..cf7fb482e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +## Added +- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights. + ## Changed - Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer. 
+## Fixed +- MIN/MAX computation for float-like (was set to infinity instead of min/max) + ## [0.20.0] - 2023-01-21 ## Added - Addition of All-MiniLM-L6-V2 model weights diff --git a/README.md b/README.md index cb2ca399a..622e74891 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ M2M100| | | |✅ | | | | | Electra | |✅| | | | |✅| | ALBERT |✅|✅|✅| | | |✅| ✅ | T5 | | | |✅ |✅|✅| | ✅ | +LongT5 | | | |✅ |✅|| | | XLNet|✅|✅|✅|✅ | | |✅| | Reformer|✅| |✅|✅ | | |✅| | ProphetNet| | | |✅ |✅ | | | | diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index 8bde2d27a..1b54c9656 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -16,7 +16,7 @@ use crate::bart::decoder::BartDecoder; use crate::bart::encoder::BartEncoder; use crate::common::activations::Activation; use crate::common::dropout::Dropout; -use crate::common::kind::get_negative_infinity; +use crate::common::kind::get_min; use crate::pipelines::common::{ModelType, TokenizerOption}; use crate::pipelines::generation_utils::private_generation_utils::{ PreparedInput, PrivateLanguageGenerator, @@ -273,7 +273,7 @@ pub(crate) fn _make_causal_mask( let mut mask = Tensor::full( &[target_length, target_length], - get_negative_infinity(dtype).unwrap(), + get_min(dtype).unwrap(), (dtype, device), ); let mask_cond = Tensor::arange(target_length, (dtype, device)); @@ -311,10 +311,7 @@ pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option, dtype: Kin .expand(&[batch_size, 1, target_length, source_length], true) .totype(dtype); let inverted_mask: Tensor = 1 - expanded_mask; - inverted_mask.masked_fill( - &inverted_mask.to_kind(Kind::Bool), - get_negative_infinity(dtype).unwrap(), - ) + inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), get_min(dtype).unwrap()) } pub(crate) fn _prepare_decoder_attention_mask( diff --git a/src/common/kind.rs b/src/common/kind.rs index 60b3dab3e..94b657683 100644 --- a/src/common/kind.rs +++ b/src/common/kind.rs @@ -38,3 +38,22 @@ pub(crate) fn get_negative_infinity(kind: Kind) -> Result } }) } + +pub(crate) fn get_min(kind: Kind) -> Result { + Ok(match kind { + Kind::Uint8 => Scalar::int(u8::MIN.into()), + Kind::Int8 => Scalar::int(i8::MIN.into()), + Kind::Int16 => Scalar::int(i16::MIN.into()), + Kind::Int => Scalar::int(i32::MIN.into()), + Kind::Int64 => Scalar::int(i64::MIN), + Kind::Half => Scalar::float(half::f16::MIN.into()), + Kind::Float => Scalar::float(f32::MIN.into()), + Kind::BFloat16 => Scalar::float(half::bf16::MIN.into()), + Kind::Double => Scalar::float(f64::MIN), + _ => { + return Err(RustBertError::ValueError(format!( + "Type not supported: attempted to get min for {kind:?}", + ))) + } + }) +} diff --git a/src/deberta/deberta_model.rs b/src/deberta/deberta_model.rs index 00f8a9cf9..680a24b17 100644 --- a/src/deberta/deberta_model.rs +++ b/src/deberta/deberta_model.rs @@ -16,7 +16,7 @@ use crate::bert::{ use crate::common::activations::TensorFunction; use crate::common::dropout::{Dropout, XDropout}; use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair; -use crate::common::kind::get_negative_infinity; +use crate::common::kind::get_min; use crate::deberta::embeddings::DebertaEmbeddings; use crate::deberta::encoder::{DebertaEncoder, DebertaEncoderOutput}; use crate::{Activation, Config, RustBertError}; @@ -264,7 +264,7 @@ impl Config for DebertaConfig {} pub fn x_softmax(input: &Tensor, mask: &Tensor, dim: i64) -> Tensor { let inverse_mask = ((1 - mask) as Tensor).to_kind(Kind::Bool); input - .masked_fill(&inverse_mask, 
get_negative_infinity(input.kind()).unwrap()) + .masked_fill(&inverse_mask, get_min(input.kind()).unwrap()) .softmax(dim, input.kind()) .masked_fill(&inverse_mask, 0.0) } diff --git a/src/lib.rs b/src/lib.rs index c53a41c6d..a077b300a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,6 +68,7 @@ //!Electra | |✅| | | | |✅| | //!ALBERT |✅|✅|✅| | | |✅| ✅ | //!T5 | | | |✅ |✅|✅| | ✅ | +//!LongT5 | | | |✅ |✅| | | | //!XLNet|✅|✅|✅|✅ | | |✅| | //!Reformer|✅| |✅|✅ | | |✅| | //!ProphetNet| | | |✅ |✅ | | | | @@ -695,6 +696,8 @@ // These are used abundantly in this code #![allow(clippy::assign_op_pattern, clippy::upper_case_acronyms)] +extern crate core; + pub mod albert; pub mod bart; pub mod bert; @@ -707,6 +710,7 @@ pub mod fnet; pub mod gpt2; pub mod gpt_neo; pub mod longformer; +pub mod longt5; pub mod m2m_100; pub mod marian; pub mod mbart; diff --git a/src/longt5/attention.rs b/src/longt5/attention.rs new file mode 100644 index 000000000..efb2e2d2e --- /dev/null +++ b/src/longt5/attention.rs @@ -0,0 +1,807 @@ +// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team. +// Copyright 2022 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::common::dropout::Dropout; +use crate::longt5::layer_norm::LongT5LayerNorm; +use crate::longt5::LongT5Config; +use crate::t5::{ + get_relative_position_bucket, LayerState as T5layerState, T5Attention, T5LayerCrossAttention, +}; +use std::borrow::Borrow; +use tch::nn::LinearConfig; +use tch::{nn, Device, IndexOp, Kind, Tensor}; + +pub type LongT5Attention = T5Attention; +pub type LongT5LayerCrossAttention = T5LayerCrossAttention; +pub type LayerState = T5layerState; + +fn pad_to_multiple(x: &Tensor, block_length: i64, dim: usize, pad_value: f64) -> Tensor { + let mut x_size = x.size(); + let pad_length = (-x_size[dim]).rem_euclid(block_length); + + if x_size.iter().any(|&el| el == 0) { + x_size[dim] += pad_length; + Tensor::zeros(x_size.as_slice(), (x.kind(), x.device())) + } else { + let mut pad = vec![0i64; 2 * x.dim()]; + pad[2 * dim] = pad_length; + pad.reverse(); + x.pad(pad.as_slice(), "constant", pad_value) + } +} + +fn split_into_blocks(x: &Tensor, block_length: i64, dim: usize) -> Tensor { + let x_size = x.size(); + let padded_x = if x_size[dim] % block_length != 0 { + Some(pad_to_multiple(x, block_length, dim, 0f64)) + } else { + None + }; + let x = padded_x.as_ref().unwrap_or(x); + let mut x_size = x.size(); + let num_blocks = x_size[dim] / block_length; + x_size.remove(dim); + x_size.insert(dim, block_length); + x_size.insert(dim, num_blocks); + if x_size.iter().any(|&el| el == 0) { + Tensor::empty(x_size.as_slice(), (x.kind(), x.device())) + } else { + x.reshape(x_size.as_slice()) + } +} + +fn concatenate_3_blocks( + x: &Tensor, + block_dim: usize, + sequence_dim: i64, + pad_value: Option, +) -> Tensor { + let x_size = x.size(); + let num_blocks = x_size[block_dim]; + let mut pad = vec![0i64; 2 * x.dim()]; + pad[2 * block_dim] = 1; + pad[2 * block_dim + 1] = 1; + pad.reverse(); + let x = x.pad(pad.as_slice(), 
"constant", pad_value.unwrap_or(0f64)); + let mut block_list: Vec = Vec::with_capacity(3); + for i in 0..3 { + block_list.push(x.narrow(block_dim as i64, i, num_blocks)); + } + Tensor::cat(block_list.as_slice(), sequence_dim) +} + +fn make_3blocks_relative_position_ids(block_length: i64, device: Device) -> Tensor { + let position_ids = Tensor::arange(3 * block_length, (Kind::Int, device)); + let center_position_ids = position_ids.i(block_length..2 * block_length); + position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1) +} + +fn mask_local_attention_mask(local_attention_mask: &Tensor, block_length: i64) -> Tensor { + let relative_position_ids = + make_3blocks_relative_position_ids(block_length, local_attention_mask.device()); + let locality_mask = relative_position_ids + .abs() + .lt(block_length) + .unsqueeze(0) + .unsqueeze(0); + local_attention_mask.logical_and(&locality_mask) +} + +pub(crate) fn get_local_attention_mask(attention_mask: &Tensor, block_length: i64) -> Tensor { + let blocked_attention_mask = split_into_blocks(attention_mask, block_length, 1); + let three_blocked_attention_mask = concatenate_3_blocks(&blocked_attention_mask, 1, 2, None); + + let blocked_attention_mask = blocked_attention_mask.unsqueeze(-1); + let three_blocked_attention_mask = three_blocked_attention_mask.unsqueeze(-2); + + let local_attention_mask = mask_local_attention_mask( + &blocked_attention_mask.logical_and(&three_blocked_attention_mask), + block_length, + ); + local_attention_mask.unsqueeze(1) +} + +fn make_global_fixed_block_ids( + attention_mask: &Tensor, + global_block_size: i64, +) -> (Tensor, Tensor) { + let &[batch_size, seq_length, ..] = attention_mask.size().as_slice() else {unreachable!()}; + + let handle_orphan_tokens = |block_ids: Tensor| -> Tensor { + let block_ends = Tensor::arange(seq_length, (Kind::Int64, block_ids.device())) + .remainder(global_block_size) + .eq(global_block_size - 1); + let true_block_ends = block_ends.logical_and(&block_ids.ge(0)); + let full_blocks = true_block_ends + .sum_dim_intlist([-1].as_slice(), false, block_ids.kind()) + .unsqueeze(-1) + - 1; + block_ids.where_self(&block_ids.lt_tensor(&full_blocks), &full_blocks) + }; + + let fixed_block_mask = attention_mask.ones_like() / global_block_size; + let fixed_block_mask = fixed_block_mask.cumsum(1, fixed_block_mask.kind()) - fixed_block_mask; + let mask = attention_mask + .ones_like() + .where_scalarother(&attention_mask.not_equal(0.0), -1000.0); + + let mut global_block_ids = (mask + fixed_block_mask - 1.0).floor(); + global_block_ids = global_block_ids.where_scalarother(&global_block_ids.gt(-1.0), -1.0); + global_block_ids = global_block_ids * attention_mask + attention_mask - 1; + global_block_ids = handle_orphan_tokens(global_block_ids); + let num_globals = seq_length / global_block_size; + let sequence_block_ids_max = if num_globals > 0 { + global_block_ids + .max_dim(-1, false) + .0 + .repeat(&[num_globals, 1]) + .transpose(0, 1) + } else { + Tensor::zeros( + &[batch_size, 0], + (global_block_ids.kind(), global_block_ids.device()), + ) + }; + let global_segment_ids = Tensor::ones( + &[batch_size, num_globals], + (attention_mask.kind(), attention_mask.device()), + ) + .cumsum(-1, attention_mask.kind()) + - 1; + let global_segment_ids = global_segment_ids + .ones_like() + .where_scalarother(&global_segment_ids.le_tensor(&sequence_block_ids_max), 0.0); + ( + global_block_ids.to_kind(Kind::Int), + global_segment_ids.to_kind(Kind::Int), + ) +} + +fn make_side_relative_position_ids(attention_mask: 
&Tensor, global_block_size: i64) -> Tensor { + let (block_ids, global_segment_ids) = + make_global_fixed_block_ids(attention_mask, global_block_size); + let global_seq_length = *global_segment_ids.size().last().unwrap(); + let global_positions = Tensor::arange(global_seq_length, (Kind::Int64, block_ids.device())); + global_positions - block_ids.unsqueeze(-1) +} + +fn create_global_aggregates( + hidden_states: &Tensor, + block_ids: &Tensor, + global_seq_length: i64, +) -> Tensor { + let block_ids = block_ids.where_scalarother(&block_ids.ge(0), global_seq_length); + let one_hot_block_ids = block_ids + .to_kind(Kind::Int64) + .one_hot(global_seq_length + 1); + let one_hot_block_ids = one_hot_block_ids.narrow(2, 0, one_hot_block_ids.size()[2] - 1); + Tensor::einsum( + "...nd,...ng->...gd", + &[ + hidden_states, + &one_hot_block_ids.to_kind(hidden_states.kind()), + ], + None, + ) +} + +fn compute_bias( + block_length: i64, + relative_attention_bias: &nn::Embedding, + is_decoder: bool, + relative_attention_num_buckets: i64, + relative_attention_max_distance: i64, +) -> Tensor { + let device = relative_attention_bias.ws.device(); + let memory_position = Tensor::arange(3 * block_length, (Kind::Int64, device)); + let context_position = memory_position.narrow(0, block_length, block_length); + let relative_position = memory_position.unsqueeze(0) - context_position.unsqueeze(-1); + + let rp_bucket = get_relative_position_bucket( + &relative_position, + !is_decoder, + relative_attention_num_buckets, + relative_attention_max_distance, + ); + rp_bucket + .apply(relative_attention_bias) + .permute(&[2, 0, 1]) + .unsqueeze(0) + .unsqueeze(0) +} + +pub struct LongT5LocalAttention { + is_decoder: bool, + has_relative_attention_bias: bool, + relative_attention_num_buckets: i64, + relative_attention_max_distance: i64, + key_value_proj_dim: i64, + n_heads: i64, + block_length: i64, + dropout: Dropout, + inner_dim: i64, + output_attentions: bool, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + output: nn::Linear, + relative_attention_bias: Option, +} + +impl LongT5LocalAttention { + pub fn new<'p, P>( + p: P, + config: &LongT5Config, + is_decoder: bool, + has_relative_attention_bias: bool, + ) -> LongT5LocalAttention + where + P: Borrow>, + { + let p = p.borrow(); + + let linear_config = LinearConfig { + bias: false, + ..Default::default() + }; + + let block_length = config.local_radius + 1; + let key_value_proj_dim = config.d_kv; + + let inner_dim = config.num_heads * config.d_kv; + let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config); + let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config); + let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config); + let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config); + + let dropout = Dropout::new(config.dropout_rate); + let relative_attention_bias = if has_relative_attention_bias { + Some(nn::embedding( + p / "relative_attention_bias", + config.relative_attention_num_buckets, + config.num_heads, + Default::default(), + )) + } else { + None + }; + + LongT5LocalAttention { + is_decoder, + has_relative_attention_bias, + relative_attention_num_buckets: config.relative_attention_num_buckets, + relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128), + key_value_proj_dim, + n_heads: config.num_heads, + block_length, + dropout, + inner_dim, + output_attentions: config.output_attentions.unwrap_or(false), + query, + key, + value, + output, + 
relative_attention_bias, + } + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + mask: Option<&Tensor>, + position_bias: Option<&Tensor>, + train: bool, + ) -> (Tensor, Option, Option) { + let input_size = hidden_states.size(); + let (batch_size, seq_length) = (input_size[0], input_size[1]); + + let shape = |states: &Tensor| -> Tensor { + states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim]) + }; + let unshape = |states: &Tensor| -> Tensor { + states.contiguous().view([batch_size, -1, self.inner_dim]) + }; + + let query_states = shape(&hidden_states.apply(&self.query)); + let key_states = shape(&hidden_states.apply(&self.key)); + let value_states = shape(&hidden_states.apply(&self.value)); + + let query_states = split_into_blocks(&query_states, self.block_length, 1); + let key_states = split_into_blocks(&key_states, self.block_length, 1); + let value_states = split_into_blocks(&value_states, self.block_length, 1); + + let key_states = concatenate_3_blocks(&key_states, 1, 2, None); + let value_states = concatenate_3_blocks(&value_states, 1, 2, None); + + let mut scores = Tensor::einsum("...qhd,...khd->...hqk", &[query_states, key_states], None); + let calc_position_bias = if position_bias.is_none() { + let mut position_bias = if !self.has_relative_attention_bias { + Tensor::zeros( + &[1, 1, self.n_heads, self.block_length, 3 * self.block_length], + (scores.kind(), scores.device()), + ) + } else { + compute_bias( + self.block_length, + self.relative_attention_bias.as_ref().unwrap(), + self.is_decoder, + self.relative_attention_num_buckets, + self.relative_attention_max_distance, + ) + }; + if let Some(mask) = mask { + let mask = mask.zeros_like().where_scalarother(&mask.gt(0), -1e10); + position_bias = position_bias + mask.transpose(1, 2); + } + Some(position_bias) + } else { + None + }; + let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap()); + scores += position_bias; + let attention_weights = scores + .to_kind(Kind::Float) + .softmax(-1, scores.kind()) + .apply_t(&self.dropout, train) + .to_kind(value_states.kind()); + let attention_output = unshape(&Tensor::einsum( + "...hqk,...khd->...qhd", + &[&attention_weights, &value_states], + None, + )) + .narrow(1, 0, seq_length) + .apply(&self.output); + + let attention_weights = if self.output_attentions { + Some(attention_weights) + } else { + None + }; + + let position_bias = if self.has_relative_attention_bias { + calc_position_bias + } else { + None + }; + (attention_output, position_bias, attention_weights) + } +} + +pub struct LongT5TransientGlobalAttention { + is_decoder: bool, + has_relative_attention_bias: bool, + relative_attention_num_buckets: i64, + relative_attention_max_distance: i64, + key_value_proj_dim: i64, + n_heads: i64, + block_length: i64, + global_block_size: i64, + dropout: Dropout, + inner_dim: i64, + output_attentions: bool, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + output: nn::Linear, + relative_attention_bias: Option, + global_relative_attention_bias: Option, + global_input_layer_norm: LongT5LayerNorm, +} + +impl LongT5TransientGlobalAttention { + pub fn new<'p, P>( + p: P, + config: &LongT5Config, + is_decoder: bool, + has_relative_attention_bias: bool, + ) -> LongT5TransientGlobalAttention + where + P: Borrow>, + { + let p = p.borrow(); + + let linear_config = LinearConfig { + bias: false, + ..Default::default() + }; + + let block_length = config.local_radius + 1; + let global_block_size = config.global_block_size; + let 
key_value_proj_dim = config.d_kv; + + let inner_dim = config.num_heads * config.d_kv; + let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config); + let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config); + let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config); + let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config); + + let dropout = Dropout::new(config.dropout_rate); + let global_relative_attention_bias = if has_relative_attention_bias { + Some(nn::embedding( + p / "global_relative_attention_bias", + config.relative_attention_num_buckets, + config.num_heads, + Default::default(), + )) + } else { + None + }; + let relative_attention_bias = if has_relative_attention_bias { + Some(nn::embedding( + p / "relative_attention_bias", + config.relative_attention_num_buckets, + config.num_heads, + Default::default(), + )) + } else { + None + }; + let global_input_layer_norm = LongT5LayerNorm::new( + p / "global_input_layer_norm", + config.d_model, + config.layer_norm_epsilon, + ); + + LongT5TransientGlobalAttention { + is_decoder, + has_relative_attention_bias, + relative_attention_num_buckets: config.relative_attention_num_buckets, + relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128), + key_value_proj_dim, + n_heads: config.num_heads, + block_length, + global_block_size, + dropout, + inner_dim, + output_attentions: config.output_attentions.unwrap_or(false), + query, + key, + value, + output, + relative_attention_bias, + global_relative_attention_bias, + global_input_layer_norm, + } + } + + fn compute_side_bias(&self, mask: &Tensor, global_segment_ids: &Tensor) -> Tensor { + let side_attention_mask = mask + .unsqueeze(-1) + .eq_tensor(&global_segment_ids.unsqueeze(1)) + .unsqueeze(1); + + let attention_side_bias = side_attention_mask + .zeros_like() + .where_scalarother(&side_attention_mask.gt(0), -1e10); + + let side_relative_position = make_side_relative_position_ids(mask, self.global_block_size); + let side_relative_position_bucket = get_relative_position_bucket( + &side_relative_position, + !self.is_decoder, + self.relative_attention_num_buckets, + self.relative_attention_max_distance, + ); + let side_bias = side_relative_position_bucket + .apply(self.global_relative_attention_bias.as_ref().unwrap()) + .permute(&[0, 3, 1, 2]); + attention_side_bias + side_bias + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + mask: Option<&Tensor>, + position_bias: Option<&Tensor>, + train: bool, + ) -> (Tensor, Option, Option) { + let input_size = hidden_states.size(); + let (batch_size, seq_length) = (input_size[0], input_size[1]); + + let shape = |states: &Tensor| -> Tensor { + states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim]) + }; + let unshape = |states: &Tensor| -> Tensor { + states.contiguous().view([batch_size, -1, self.inner_dim]) + }; + let calc_mask = if mask.is_none() { + let mut mask_size = input_size; + let _ = mask_size.pop(); + Some(Tensor::ones( + mask_size.as_slice(), + (Kind::Bool, hidden_states.device()), + )) + } else { + None + }; + let (block_ids, global_segment_ids) = make_global_fixed_block_ids( + mask.unwrap_or_else(|| calc_mask.as_ref().unwrap()), + self.global_block_size, + ); + let global_seq_length = *global_segment_ids.size().last().unwrap(); + let global_inputs = create_global_aggregates(hidden_states, &block_ids, global_seq_length) + .apply(&self.global_input_layer_norm); + + let query_states = shape(&hidden_states.apply(&self.query)); + 
let key_states = shape(&hidden_states.apply(&self.key)); + let value_states = shape(&hidden_states.apply(&self.value)); + + let side_key_states = shape(&global_inputs.apply(&self.key)); + let side_value_states = shape(&global_inputs.apply(&self.value)); + + let query_states = split_into_blocks(&query_states, self.block_length, 1); + let key_states = split_into_blocks(&key_states, self.block_length, 1); + let value_states = split_into_blocks(&value_states, self.block_length, 1); + + let key_states = concatenate_3_blocks(&key_states, 1, 2, None); + let value_states = concatenate_3_blocks(&value_states, 1, 2, None); + + let mut reps = vec![1; side_key_states.dim() + 1]; + reps[1] = key_states.size()[1]; + let side_key_states = side_key_states.unsqueeze(1).repeat(reps.as_slice()); + let side_value_states = side_value_states.unsqueeze(1).repeat(reps.as_slice()); + let key_states = Tensor::cat(&[key_states, side_key_states], 2); + let value_states = Tensor::cat(&[value_states, side_value_states], 2); + + let mut scores = Tensor::einsum("...qhd,...khd->...hqk", &[query_states, key_states], None); + let local_attention_mask = mask.map(|mask| { + let local_attention_mask = get_local_attention_mask(mask, self.block_length); + local_attention_mask + .zeros_like() + .where_scalarother(&local_attention_mask.gt(0), -1e10) + }); + + let calc_position_bias = if position_bias.is_none() { + let mut position_bias = if !self.has_relative_attention_bias { + Tensor::zeros( + &[1, 1, self.n_heads, self.block_length, 3 * self.block_length], + (scores.kind(), scores.device()), + ) + } else { + compute_bias( + self.block_length, + self.relative_attention_bias.as_ref().unwrap(), + self.is_decoder, + self.relative_attention_num_buckets, + self.relative_attention_max_distance, + ) + }; + if let Some(local_attention_mask) = local_attention_mask { + position_bias = position_bias + local_attention_mask.transpose(1, 2); + } + let calc_mask = if mask.is_none() { + Some(Tensor::ones( + &[batch_size, seq_length], + (global_segment_ids.kind(), global_segment_ids.device()), + )) + } else { + None + }; + let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap()); + let side_position_bias = self.compute_side_bias(mask, &global_segment_ids); + let side_position_bias = split_into_blocks( + &side_position_bias, + self.block_length, + side_position_bias.dim() - 2, + ) + .transpose(1, 2); + let position_bias = Tensor::cat(&[position_bias, side_position_bias], -1); + + Some(position_bias) + } else { + None + }; + let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap()); + + scores += position_bias; + let attention_weights = scores + .to_kind(Kind::Float) + .softmax(-1, scores.kind()) + .apply_t(&self.dropout, train); + + let attention_output = unshape(&Tensor::einsum( + "...hqk,...khd->...qhd", + &[&attention_weights, &value_states], + None, + )) + .narrow(1, 0, seq_length) + .apply(&self.output); + + let attention_weights = if self.output_attentions { + Some(attention_weights) + } else { + None + }; + + let position_bias = if self.has_relative_attention_bias { + calc_position_bias + } else { + None + }; + (attention_output, position_bias, attention_weights) + } +} + +pub struct LongT5LayerSelfAttention { + self_attention: LongT5Attention, + layer_norm: LongT5LayerNorm, + dropout: Dropout, +} + +impl LongT5LayerSelfAttention { + pub fn new<'p, P>( + p: P, + config: &LongT5Config, + has_relative_attention_bias: bool, + is_decoder: bool, + store_cache: bool, + output_attentions: bool, + ) -> 
LongT5LayerSelfAttention + where + P: Borrow>, + { + let p = p.borrow(); + + let self_attention = LongT5Attention::new( + p / "SelfAttention", + &config.into(), + is_decoder, + !is_decoder, + store_cache, + output_attentions, + has_relative_attention_bias, + ); + + let layer_norm = + LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon); + let dropout = Dropout::new(config.dropout_rate); + + LongT5LayerSelfAttention { + self_attention, + layer_norm, + dropout, + } + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + position_bias: Option<&Tensor>, + attention_mask: Option<&Tensor>, + layer_state: Option, + train: bool, + ) -> (Tensor, Option, Option, Option) { + let norm_x = hidden_states.apply(&self.layer_norm); + + let (y, attention_weights, position_bias, layer_state) = self.self_attention.forward_t( + &norm_x, + None, + position_bias, + attention_mask, + layer_state, + None, + train, + ); + + let output = hidden_states + y.apply_t(&self.dropout, train); + + (output, attention_weights, position_bias, layer_state) + } +} + +pub struct LongT5LayerLocalSelfAttention { + local_self_attention: LongT5LocalAttention, + layer_norm: LongT5LayerNorm, + dropout: Dropout, +} + +impl LongT5LayerLocalSelfAttention { + pub fn new<'p, P>( + p: P, + config: &LongT5Config, + has_relative_attention_bias: bool, + is_decoder: bool, + ) -> LongT5LayerLocalSelfAttention + where + P: Borrow>, + { + let p = p.borrow(); + + let local_self_attention = LongT5LocalAttention::new( + p / "LocalSelfAttention", + config, + is_decoder, + has_relative_attention_bias, + ); + + let layer_norm = + LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon); + let dropout = Dropout::new(config.dropout_rate); + + LongT5LayerLocalSelfAttention { + local_self_attention, + layer_norm, + dropout, + } + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + attention_mask: Option<&Tensor>, + position_bias: Option<&Tensor>, + train: bool, + ) -> (Tensor, Option, Option) { + let normed_hidden_states = hidden_states.apply(&self.layer_norm); + + let (attention_output, position_bias, attention_weights) = self + .local_self_attention + .forward_t(&normed_hidden_states, attention_mask, position_bias, train); + + let output = hidden_states + attention_output.apply_t(&self.dropout, train); + + (output, position_bias, attention_weights) + } +} + +pub struct LongT5LayerTransientGlobalSelfAttention { + transient_global_sef_attention: LongT5TransientGlobalAttention, + layer_norm: LongT5LayerNorm, + dropout: Dropout, +} + +impl LongT5LayerTransientGlobalSelfAttention { + pub fn new<'p, P>( + p: P, + config: &LongT5Config, + has_relative_attention_bias: bool, + is_decoder: bool, + ) -> LongT5LayerTransientGlobalSelfAttention + where + P: Borrow>, + { + let p = p.borrow(); + + let transient_global_sef_attention = LongT5TransientGlobalAttention::new( + p / "TransientGlobalSelfAttention", + config, + is_decoder, + has_relative_attention_bias, + ); + + let layer_norm = + LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon); + let dropout = Dropout::new(config.dropout_rate); + + LongT5LayerTransientGlobalSelfAttention { + transient_global_sef_attention, + layer_norm, + dropout, + } + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + attention_mask: Option<&Tensor>, + position_bias: Option<&Tensor>, + train: bool, + ) -> (Tensor, Option, Option) { + let normed_hidden_states = hidden_states.apply(&self.layer_norm); + let (attention_output, 
position_bias, attention_weights) = self + .transient_global_sef_attention + .forward_t(&normed_hidden_states, attention_mask, position_bias, train); + + let output = hidden_states + attention_output.apply_t(&self.dropout, train); + + (output, position_bias, attention_weights) + } +} diff --git a/src/longt5/encoder.rs b/src/longt5/encoder.rs new file mode 100644 index 000000000..463b6f922 --- /dev/null +++ b/src/longt5/encoder.rs @@ -0,0 +1,457 @@ +// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team. +// Copyright 2022 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::common::dropout::Dropout; +use crate::common::embeddings::process_ids_embeddings_pair; +use crate::common::kind::get_min; +use crate::longt5::attention::{ + get_local_attention_mask, LayerState, LongT5LayerCrossAttention, LongT5LayerLocalSelfAttention, + LongT5LayerSelfAttention, LongT5LayerTransientGlobalSelfAttention, +}; +use crate::longt5::layer_norm::LongT5LayerNorm; +use crate::longt5::longt5_model::EncoderAttentionType; +use crate::longt5::LongT5Config; +use crate::t5::{T5Block, T5BlockOutput, T5LayerFF, T5StackOutput}; +use crate::RustBertError; +use std::borrow::{Borrow, BorrowMut}; +use tch::{nn, Kind, Tensor}; + +pub type LongT5LayerFF = T5LayerFF; + +enum LongT5AttentionLayer { + SelfAttention(LongT5LayerSelfAttention), + Local(LongT5LayerLocalSelfAttention), + Global(LongT5LayerTransientGlobalSelfAttention), +} + +impl LongT5AttentionLayer { + pub fn forward_t( + &self, + hidden_states: &Tensor, + position_bias: Option<&Tensor>, + attention_mask: Option<&Tensor>, + layer_state: Option, + train: bool, + ) -> (Tensor, Option, Option, Option) { + match self { + LongT5AttentionLayer::SelfAttention(ref layer) => layer.forward_t( + hidden_states, + position_bias, + attention_mask, + layer_state, + train, + ), + LongT5AttentionLayer::Local(ref layer) => { + let (output, position_bias, attention_weights) = + layer.forward_t(hidden_states, attention_mask, position_bias, train); + (output, attention_weights, position_bias, None) + } + LongT5AttentionLayer::Global(ref layer) => { + let (output, position_bias, attention_weights) = + layer.forward_t(hidden_states, attention_mask, position_bias, train); + (output, attention_weights, position_bias, None) + } + } + } +} + +pub struct LongT5Block { + attention_layer: LongT5AttentionLayer, + cross_attention: Option, + ff_layer: LongT5LayerFF, +} + +impl LongT5Block { + pub fn new<'p, P>( + p: P, + config: &LongT5Config, + has_relative_attention_bias: bool, + is_decoder: bool, + store_cache: bool, + output_attentions: bool, + ) -> LongT5Block + where + P: Borrow>, + { + let p = p.borrow() / "layer"; + let mut module_index = 0; + + let attention_layer = if is_decoder { + LongT5AttentionLayer::SelfAttention(LongT5LayerSelfAttention::new( + &p / module_index, + config, + has_relative_attention_bias, + is_decoder, + store_cache, + output_attentions, + )) + } else { + match config.encoder_attention_type { + Some(EncoderAttentionType::Local) 
| None => { + LongT5AttentionLayer::Local(LongT5LayerLocalSelfAttention::new( + &p / module_index, + config, + has_relative_attention_bias, + is_decoder, + )) + } + Some(EncoderAttentionType::TransientGlobal) => { + LongT5AttentionLayer::Global(LongT5LayerTransientGlobalSelfAttention::new( + &p / module_index, + config, + has_relative_attention_bias, + is_decoder, + )) + } + } + }; + + let cross_attention = if is_decoder { + module_index += 1; + Some(LongT5LayerCrossAttention::new( + &p / module_index, + &config.into(), + false, + is_decoder, + store_cache, + output_attentions, + )) + } else { + None + }; + module_index += 1; + + let ff_layer = LongT5LayerFF::new(&p / module_index, &config.into()); + + LongT5Block { + attention_layer, + cross_attention, + ff_layer, + } + } + + pub fn forward_t( + &self, + hidden_states: &Tensor, + attention_mask: Option<&Tensor>, + position_bias: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + encoder_attention_mask: Option<&Tensor>, + encoder_decoder_position_bias: Option<&Tensor>, + mut layer_states: (Option, Option), + train: bool, + ) -> LongT5BlockOutput { + let ( + mut hidden_states, + self_attention_weights, + self_attention_position_bias, + self_attention_layer_past, + ) = self.attention_layer.forward_t( + hidden_states, + position_bias, + attention_mask, + layer_states.0, + train, + ); + + hidden_states = T5Block::clamp_hidden_states(hidden_states); + + let ( + mut hidden_states, + cross_attention_weights, + cross_attention_position_bias, + cross_attention_layer_past, + ) = if self.cross_attention.is_some() & encoder_hidden_states.is_some() { + let query_length = self_attention_layer_past + .as_ref() + .map(|value| value.prev_key.size()[2]); + self.cross_attention.as_ref().unwrap().forward_t( + &hidden_states, + encoder_hidden_states, + encoder_decoder_position_bias, + encoder_attention_mask, + layer_states.1, + query_length, + train, + ) + } else { + (hidden_states, None, None, None) + }; + + hidden_states = T5Block::clamp_hidden_states(hidden_states); + + layer_states = (self_attention_layer_past, cross_attention_layer_past); + let mut hidden_states = self.ff_layer.forward_t(&hidden_states, train); + + hidden_states = T5Block::clamp_hidden_states(hidden_states); + + LongT5BlockOutput { + hidden_states, + self_attention_weights, + cross_attention_weights, + self_attention_position_bias, + cross_attention_position_bias, + cache: layer_states, + } + } +} + +pub struct LongT5Stack { + blocks: Vec, + final_layer_norm: LongT5LayerNorm, + dropout: Dropout, + output_attentions: bool, + output_hidden_states: bool, + is_decoder: bool, + store_cache: bool, + encoder_attention_type: EncoderAttentionType, + block_length: i64, +} + +impl LongT5Stack { + pub fn new<'p, P>( + p: P, + config: &LongT5Config, + is_decoder: bool, + store_cache: bool, + output_attentions: bool, + output_hidden_states: bool, + ) -> LongT5Stack + where + P: Borrow>, + { + let p = p.borrow(); + let dropout = Dropout::new(config.dropout_rate); + + let mut blocks: Vec = vec![]; + let p_layers = p / "block"; + for layer_index in 0..config.num_layers { + blocks.push(LongT5Block::new( + &p_layers / layer_index, + config, + layer_index == 0, + is_decoder, + store_cache, + output_attentions, + )); + } + + let final_layer_norm = LongT5LayerNorm::new( + p / "final_layer_norm", + config.d_model, + config.layer_norm_epsilon, + ); + + let encoder_attention_type = config + .encoder_attention_type + .unwrap_or(EncoderAttentionType::Local); + + let block_length = config.local_radius + 1; 
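+        // Block length mirrors the attention modules (`local_radius + 1`); the stack only uses it
+        // to build the local attention mask for encoder stacks configured with `Local` attention.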
+ + LongT5Stack { + blocks, + final_layer_norm, + dropout, + output_attentions, + output_hidden_states, + is_decoder, + store_cache, + encoder_attention_type, + block_length, + } + } + + pub fn forward_t( + &self, + input_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + encoder_attention_mask: Option<&Tensor>, + input_embeds: Option<&Tensor>, + embeddings: &nn::Embedding, + old_layer_states: Option, Option)>>, + train: bool, + ) -> Result { + let (calc_input_embeddings, input_shape, _) = + process_ids_embeddings_pair(input_ids, input_embeds, embeddings)?; + let input_embeddings = + input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap()); + + let (batch_size, sequence_length) = (input_shape[0], input_shape[1]); + + let mask_seq_length = if old_layer_states.is_some() { + if old_layer_states.as_ref().unwrap()[0].0.is_some() { + old_layer_states.as_ref().unwrap()[0] + .0 + .as_ref() + .unwrap() + .prev_key + .size()[2] + + sequence_length + } else { + sequence_length + } + } else { + sequence_length + }; + + let calculated_attention_mask = if attention_mask.is_none() { + Some(Tensor::ones( + &[batch_size, mask_seq_length], + (Kind::Int64, input_embeddings.device()), + )) + } else { + None + }; + let attention_mask = + attention_mask.unwrap_or_else(|| calculated_attention_mask.as_ref().unwrap()); + + let extended_attention_mask = if self.is_decoder { + let extended_attention_mask = match attention_mask.dim() { + 3 => attention_mask.unsqueeze(1), + 2 => { + if self.is_decoder { + let seq_ids = Tensor::arange( + sequence_length, + (input_embeddings.kind(), input_embeddings.device()), + ); + let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[ + batch_size, + sequence_length, + 1, + ]); + let causal_mask = + causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1)); + causal_mask.unsqueeze(1) * attention_mask.unsqueeze(1).unsqueeze(1) + } else { + attention_mask.unsqueeze(1).unsqueeze(1) + } + } + _ => { + return Err(RustBertError::ValueError( + "Invalid attention mask dimension, must be 2 or 3".into(), + )); + } + }; + Some( + (extended_attention_mask.ones_like() - extended_attention_mask) + .to_kind(input_embeddings.kind()) + * get_min(input_embeddings.kind()).unwrap(), + ) + } else if let EncoderAttentionType::Local = self.encoder_attention_type { + Some(get_local_attention_mask(attention_mask, self.block_length)) + } else { + None + }; + let extended_attention_mask = extended_attention_mask.as_ref().unwrap_or(attention_mask); + + let encoder_extended_attention_mask = if self.is_decoder & encoder_hidden_states.is_some() { + let new_shape = &encoder_hidden_states.as_ref().unwrap().size()[..2]; + let calculated_encoder_attention_mask = if encoder_attention_mask.is_none() { + Some(Tensor::ones( + &[batch_size, new_shape[1]], + (Kind::Int64, input_embeddings.device()), + )) + } else { + None + }; + let encoder_attention_mask = encoder_attention_mask + .unwrap_or_else(|| calculated_encoder_attention_mask.as_ref().unwrap()); + + let mut encoder_extended_attention_mask = + encoder_attention_mask.to_kind(input_embeddings.kind()); + if encoder_extended_attention_mask.dim() == 3 { + encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze_(1); + } else if encoder_extended_attention_mask.dim() == 2 { + encoder_extended_attention_mask = + encoder_extended_attention_mask.unsqueeze_(1).unsqueeze_(1); + }; + Some( + (encoder_extended_attention_mask.ones_like() - encoder_extended_attention_mask) + * 
get_min(input_embeddings.kind()).unwrap(), + ) + } else { + None + }; + + let mut all_hidden_states: Option> = if self.output_hidden_states { + Some(Vec::with_capacity(self.blocks.len())) + } else { + None + }; + let mut all_attentions: Option> = if self.output_attentions { + Some(Vec::with_capacity(self.blocks.len())) + } else { + None + }; + let mut next_cache: Option, Option)>> = + if self.store_cache { + if old_layer_states.is_some() { + old_layer_states + } else { + Some(vec![(None, None); self.blocks.len()]) + } + } else { + None + }; + let mut position_bias = None; + let mut encoder_decoder_position_bias = None; + let mut attention_weights: Option; + let mut hidden_state = input_embeddings.apply_t(&self.dropout, train); + + for (layer_idx, layer) in self.blocks.iter().enumerate() { + let layer_state = match &mut next_cache { + Some(values) => std::mem::take(&mut values[layer_idx]), + None => (None, None), + }; + let block_output = layer.forward_t( + &hidden_state, + Some(extended_attention_mask), + position_bias.as_ref(), + encoder_hidden_states, + encoder_extended_attention_mask.as_ref(), + encoder_decoder_position_bias.as_ref(), + layer_state, + train, + ); + if layer_idx == 0 { + position_bias = block_output.self_attention_position_bias; + encoder_decoder_position_bias = block_output.cross_attention_position_bias; + } + hidden_state = block_output.hidden_states; + attention_weights = block_output.cross_attention_weights; + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); + }; + if let Some(attentions) = all_attentions.borrow_mut() { + attentions.push(std::mem::take(&mut attention_weights.unwrap())); + }; + if let Some(value) = &mut next_cache { + value[layer_idx] = block_output.cache + }; + } + + let hidden_state = hidden_state + .apply(&self.final_layer_norm) + .apply_t(&self.dropout, train); + + Ok(LongT5StackOutput { + hidden_state, + all_hidden_states, + all_attentions, + next_cache, + }) + } +} + +pub type LongT5BlockOutput = T5BlockOutput; +pub type LongT5StackOutput = T5StackOutput; diff --git a/src/longt5/layer_norm.rs b/src/longt5/layer_norm.rs new file mode 100644 index 000000000..459ee3e66 --- /dev/null +++ b/src/longt5/layer_norm.rs @@ -0,0 +1,15 @@ +// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team. +// Copyright 2022 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::t5::T5LayerNorm; + +pub type LongT5LayerNorm = T5LayerNorm; diff --git a/src/longt5/longt5_model.rs b/src/longt5/longt5_model.rs new file mode 100644 index 000000000..53fb9d2cf --- /dev/null +++ b/src/longt5/longt5_model.rs @@ -0,0 +1,894 @@ +// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team. +// Copyright 2022 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::longt5::encoder::LongT5Stack; +use crate::longt5::LayerState; +use crate::pipelines::common::{ModelType, TokenizerOption}; +use crate::pipelines::generation_utils::private_generation_utils::{ + PreparedInput, PrivateLanguageGenerator, +}; +use crate::pipelines::generation_utils::{ + Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator, +}; +use crate::t5::{FeedForwardProj, T5Config, T5ModelOutput, TaskSpecificParams}; +use crate::{Config, RustBertError}; +use rust_tokenizers::tokenizer::{T5Tokenizer, TruncationStrategy}; +use rust_tokenizers::vocab::T5Vocab; +use serde::{Deserialize, Serialize}; +use std::borrow::Borrow; +use tch::nn::{embedding, LinearConfig}; +use tch::{nn, Tensor}; + +/// # LongT5 Pretrained model weight files +pub struct LongT5ModelResources; + +/// # LongT5 Pretrained model config files +pub struct LongT5ConfigResources; + +/// # LongT5 Pretrained model vocab files +pub struct LongT5VocabResources; + +impl LongT5ModelResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const TGLOBAL_BASE_BOOK_SUMMARY: (&'static str, &'static str) = ( + "longt5-tglobal-base-book-summary/model", + "https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary/resolve/main/rust_model.ot", + ); +} + +impl LongT5ConfigResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const TGLOBAL_BASE_BOOK_SUMMARY: (&'static str, &'static str) = ( + "longt5-tglobal-base-book-summary/config", + "https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary/resolve/main/config.json", + ); +} + +impl LongT5VocabResources { + /// Shared under Apache 2.0 license at . Modified with conversion to C-array format. + pub const TGLOBAL_BASE_BOOK_SUMMARY: (&'static str, &'static str) = ( + "longt5-tglobal-base-book-summary/spiece", + "https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary/resolve/main/spiece.model", + ); +} + +#[derive(Clone, Debug, Serialize, Deserialize, Copy)] +#[serde(rename_all = "kebab-case")] +/// # Options for LongT5 encoder attention type +pub enum EncoderAttentionType { + /// Local + Local, + /// Transient Global + TransientGlobal, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +/// # LongT5 model configuration +/// Defines the LongT5 model architecture (e.g. number of layers, hidden layer size, label mapping...) 
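+/// Compared to `T5Config`, it adds the LongT5-specific `local_radius`, `global_block_size` and
+/// `encoder_attention_type` fields used to parameterize the encoder's local / transient-global attention.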
+pub struct LongT5Config { + pub dropout_rate: f64, + pub d_model: i64, + pub d_ff: i64, + pub d_kv: i64, + pub decoder_start_token_id: Option, + pub bos_token_id: Option, + pub eos_token_id: Option, + pub initializer_factor: f64, + pub is_encoder_decoder: Option, + pub layer_norm_epsilon: f64, + pub num_heads: i64, + pub num_layers: i64, + pub num_decoder_layers: Option, + pub local_radius: i64, + pub global_block_size: i64, + pub output_past: Option, + pub pad_token_id: Option, + pub relative_attention_num_buckets: i64, + pub relative_attention_max_distance: Option, + pub encoder_attention_type: Option, + pub vocab_size: i64, + pub feed_forward_proj: Option, + pub tie_word_embeddings: Option, + pub task_specific_params: Option, + pub output_attentions: Option, + pub output_hidden_states: Option, +} + +impl Config for LongT5Config {} + +impl Default for LongT5Config { + fn default() -> Self { + LongT5Config { + dropout_rate: 0.1, + d_model: 512, + d_ff: 2048, + d_kv: 64, + decoder_start_token_id: None, + bos_token_id: None, + eos_token_id: Some(1), + initializer_factor: 1.0, + is_encoder_decoder: None, + layer_norm_epsilon: 1e-6, + num_heads: 8, + num_layers: 6, + num_decoder_layers: None, + local_radius: 127, + global_block_size: 16, + output_past: None, + pad_token_id: Some(0), + relative_attention_num_buckets: 32, + relative_attention_max_distance: Some(128), + encoder_attention_type: Some(EncoderAttentionType::Local), + vocab_size: 32128, + feed_forward_proj: Some(FeedForwardProj::Relu), + tie_word_embeddings: None, + task_specific_params: None, + output_attentions: None, + output_hidden_states: None, + } + } +} + +impl From<&LongT5Config> for T5Config { + fn from(val: &LongT5Config) -> T5Config { + T5Config { + dropout_rate: val.dropout_rate, + d_model: val.d_model, + d_ff: val.d_ff, + d_kv: val.d_kv, + decoder_start_token_id: val.decoder_start_token_id, + bos_token_id: None, + eos_token_id: val.eos_token_id, + initializer_factor: val.initializer_factor, + is_encoder_decoder: val.is_encoder_decoder, + layer_norm_epsilon: val.layer_norm_epsilon, + num_heads: val.num_heads, + num_layers: val.num_layers, + output_past: val.output_past, + pad_token_id: val.pad_token_id, + relative_attention_num_buckets: val.relative_attention_num_buckets, + relative_attention_max_distance: val.relative_attention_max_distance, + vocab_size: val.vocab_size, + feed_forward_proj: val.feed_forward_proj, + tie_word_embeddings: val.tie_word_embeddings, + task_specific_params: val.task_specific_params.clone(), + output_attentions: val.output_attentions, + output_hidden_states: val.output_hidden_states, + } + } +} + +/// # LongT5 Base model +/// Base architecture for LongT5 model. Usually complemented with a task-specific head, such as a language model head. +/// It is made of the following blocks: +/// - `encoder`: `T5Stack` (transformer) made of a vector of encoding layers +/// - `decoder`: `T5Stack` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention. +/// caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values) +/// - `embeddings`: `nn::Embedding` Shared embeddings for the encoder and decoder. 
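+/// The encoder blocks use local or transient-global self-attention (depending on `encoder_attention_type`),
+/// while the decoder blocks use the standard T5 self-attention and cross-attention.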
+pub struct LongT5Model { + pub(crate) encoder: LongT5Stack, + decoder: LongT5Stack, + pub(crate) embeddings: nn::Embedding, +} + +impl LongT5Model { + /// Build a new `LongT5Model` + /// + /// # Arguments + /// + /// * `p` - Variable store path for the root of the LongT5 model + /// * `config` - `LongT5Config` object defining the model architecture + /// + /// # Example + /// + /// ```no_run + /// use rust_bert::longt5::{LongT5Config, LongT5Model}; + /// use rust_bert::Config; + /// use std::path::Path; + /// use tch::{nn, Device}; + /// + /// let config_path = Path::new("path/to/config.json"); + /// let device = Device::Cpu; + /// let p = nn::VarStore::new(device); + /// let config = LongT5Config::from_file(config_path); + /// let long_t5: LongT5Model = LongT5Model::new(&p.root() / "longt5", &config); + /// ``` + pub fn new<'p, P>(p: P, config: &LongT5Config) -> LongT5Model + where + P: Borrow>, + { + let p = p.borrow(); + + let embeddings: nn::Embedding = embedding( + p / "shared", + config.vocab_size, + config.d_model, + Default::default(), + ); + + let encoder = LongT5Stack::new( + p / "encoder", + config, + false, + false, + config.output_attentions.unwrap_or(false), + config.output_hidden_states.unwrap_or(false), + ); + let decoder = LongT5Stack::new( + p / "decoder", + config, + true, + true, + config.output_attentions.unwrap_or(false), + config.output_hidden_states.unwrap_or(false), + ); + + LongT5Model { + encoder, + decoder, + embeddings, + } + } + + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided. + /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked. + /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*). + /// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided. + /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked. + /// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided. + /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided. + /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. 
+ /// + /// # Returns + /// + /// * `LongT5ModelOutput` containing: + /// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state + /// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state + /// - `cache` - `Option>)>>` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder. + /// - `all_encoder_hidden_states` - `Option>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*) + /// - `all_encoder_attentions` - `Option>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*) + /// - `all_decoder_hidden_states` - `Option>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*) + /// - `all_decoder_attentions` - `Option>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*) + /// + /// # Example + /// + /// ```no_run + /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use rust_bert::Config; + /// # use std::path::Path; + /// # use tch::kind::Kind::{Int64, Double}; + /// use rust_bert::longt5::{LongT5Config, LongT5Model}; + /// # let config_path = Path::new("path/to/config.json"); + /// # let vocab_path = Path::new("path/to/vocab.txt"); + /// # let device = Device::Cpu; + /// # let vs = nn::VarStore::new(device); + /// # let config = LongT5Config::from_file(config_path); + /// # let longt5_model: LongT5Model = LongT5Model::new(&vs.root(), &config); + /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); + /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); + /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); + /// let encoder_attention_mask = + /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// let decoder_attention_mask = + /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// + /// let model_output = no_grad(|| { + /// longt5_model.forward_t( + /// Some(&input_tensor), + /// Some(&encoder_attention_mask), + /// None, + /// Some(&target_tensor), + /// Some(&decoder_attention_mask), + /// None, + /// None, + /// None, + /// false, + /// ) + /// }); + /// ``` + pub fn forward_t( + &self, + input_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + decoder_attention_mask: Option<&Tensor>, + input_embeds: Option<&Tensor>, + decoder_input_embeds: Option<&Tensor>, + old_layer_states: Option, Option)>>, + train: bool, + ) -> Result { + let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) = + if encoder_outputs.is_none() { + let encoder_output = self.encoder.forward_t( + input_ids, + attention_mask, + None, + None, + input_embeds, + &self.embeddings, + None, + train, + )?; + ( + Some(encoder_output.hidden_state), + encoder_output.all_hidden_states, + encoder_output.all_attentions, + ) + } else { + (None, None, None) + }; + + let encoder_output = + encoder_outputs.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap()); + + let decoder_output = self + .decoder + .forward_t( + decoder_input_ids, + decoder_attention_mask, + 
Some(encoder_output),
+                attention_mask,
+                decoder_input_embeds,
+                &self.embeddings,
+                old_layer_states,
+                train,
+            )
+            .unwrap();
+        Ok(LongT5ModelOutput {
+            decoder_output: decoder_output.hidden_state,
+            encoder_hidden_state: calc_hidden_states,
+            next_cache: decoder_output.next_cache,
+            all_decoder_hidden_states: decoder_output.all_hidden_states,
+            all_decoder_attentions: decoder_output.all_attentions,
+            all_encoder_hidden_states,
+            all_encoder_attentions,
+        })
+    }
+}
+
+/// # LongT5 Model for conditional generation
+/// LongT5 model with a vocabulary decoding head
+/// It is made of the following blocks:
+/// - `base_model`: `LongT5Model` Base LongT5 model
+/// - `model_dim`: `f64` representation of the model dimension for scaling of the generated logits
+pub struct LongT5ForConditionalGeneration {
+    base_model: LongT5Model,
+    model_dim: f64,
+    tie_word_embeddings: bool,
+    lm_head: Option<nn::Linear>,
+}
+
+impl LongT5ForConditionalGeneration {
+    /// Build a new `LongT5ForConditionalGeneration`
+    ///
+    /// # Arguments
+    ///
+    /// * `p` - Variable store path for the root of the LongT5 model
+    /// * `config` - `LongT5Config` object defining the model architecture
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration};
+    /// use rust_bert::Config;
+    /// use std::path::Path;
+    /// use tch::{nn, Device};
+    ///
+    /// let config_path = Path::new("path/to/config.json");
+    /// let device = Device::Cpu;
+    /// let p = nn::VarStore::new(device);
+    /// let config = LongT5Config::from_file(config_path);
+    /// let longt5 = LongT5ForConditionalGeneration::new(&p.root() / "t5", &config);
+    /// ```
+    pub fn new<'p, P>(p: P, config: &LongT5Config) -> LongT5ForConditionalGeneration
+    where
+        P: Borrow<nn::Path<'p>>,
+    {
+        let p = p.borrow();
+
+        let base_model = LongT5Model::new(p, config);
+        let tie_word_embeddings = config.tie_word_embeddings.unwrap_or(true);
+
+        let lm_head = if !tie_word_embeddings {
+            Some(nn::linear(
+                p / "lm_head",
+                config.d_model,
+                config.vocab_size,
+                LinearConfig {
+                    bias: false,
+                    ..Default::default()
+                },
+            ))
+        } else {
+            None
+        };
+
+        LongT5ForConditionalGeneration {
+            base_model,
+            model_dim: config.d_model as f64,
+            tie_word_embeddings,
+            lm_head,
+        }
+    }
+
+    /// Forward pass through the model
+    ///
+    /// # Arguments
+    ///
+    /// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided.
+    /// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
+    /// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
+    ///   These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
+    /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided.
+    /// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
+    /// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
+    /// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
+    /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
+    /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
+    ///
+    /// # Returns
+    ///
+    /// * `LongT5ModelOutput` containing:
+    ///   - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each sequence position and vocabulary item
+    ///   - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
+    ///   - `cache` - `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
+    ///   - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
+    ///   - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
+    ///   - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
+    ///   - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # use tch::{nn, Device, Tensor, no_grad};
+    /// # use rust_bert::Config;
+    /// # use std::path::Path;
+    /// # use tch::kind::Kind::{Int64, Double};
+    /// use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration};
+    /// # let config_path = Path::new("path/to/config.json");
+    /// # let vocab_path = Path::new("path/to/vocab.txt");
+    /// # let device = Device::Cpu;
+    /// # let vs = nn::VarStore::new(device);
+    /// # let config = LongT5Config::from_file(config_path);
+    /// # let longt5_model: LongT5ForConditionalGeneration = LongT5ForConditionalGeneration::new(&vs.root(), &config);
+    /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
+    /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
+    /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
+    /// let encoder_attention_mask =
+    ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
+    /// let decoder_attention_mask =
+    ///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
+    ///
+    /// let model_output = no_grad(|| {
+    ///     longt5_model.forward_t(
+    ///         Some(&input_tensor),
+    ///         Some(&encoder_attention_mask),
+    ///         None,
+    ///         Some(&target_tensor),
+    ///         Some(&decoder_attention_mask),
+    ///         None,
+    ///         None,
+    ///         None,
+    ///         false,
+    ///     )
+    /// });
+    /// ```
+    pub fn forward_t(
+        &self,
+        input_ids: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+        encoder_outputs: Option<&Tensor>,
+        decoder_input_ids: Option<&Tensor>,
+        decoder_attention_mask:
Option<&Tensor>, + input_embeds: Option<&Tensor>, + decoder_input_embeds: Option<&Tensor>, + old_layer_states: Option, Option)>>, + train: bool, + ) -> Result { + let base_model_output = self.base_model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + decoder_attention_mask, + input_embeds, + decoder_input_embeds, + old_layer_states, + train, + )?; + + let lm_logits = if self.tie_word_embeddings { + base_model_output + .decoder_output + .linear::(&self.base_model.embeddings.ws, None) + * (self.model_dim.powf(-0.5)) + } else { + base_model_output + .decoder_output + .apply(self.lm_head.as_ref().unwrap()) + }; + + Ok(T5ModelOutput { + decoder_output: lm_logits, + ..base_model_output + }) + } + + pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor { + self.base_model + .encoder + .forward_t( + Some(input_ids), + attention_mask, + None, + None, + None, + &self.base_model.embeddings, + None, + false, + ) + .unwrap() + .hidden_state + } +} + +impl LMHeadModel for LongT5ForConditionalGeneration { + /// Forward pass through the model + /// + /// # Arguments + /// + /// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`) + /// * `cache` - `Cache` object containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding. + /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1 + /// * `input_embeds` - Unused for LongT5 + /// * `token_type_ids` - Unused for LongT5 + /// * `position_ids` - Unused for LongT5 + /// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks. + /// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). + /// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference. + /// + /// # Returns + /// + /// * `LMModelOutput` containing: + /// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position + /// - `cache` - `T5Cache` made of `Option>)>>` of length *n_layer* containing the encoder past keys and values for + /// both the self attention and the encoder cross attention of each layer of the decoder. 
+ /// + /// # Example + /// + /// ```no_run + /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use rust_bert::Config; + /// # use std::path::Path; + /// # use tch::kind::Kind::{Int64, Double}; + /// use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration}; + /// # let config_path = Path::new("path/to/config.json"); + /// # let vocab_path = Path::new("path/to/vocab.txt"); + /// # let device = Device::Cpu; + /// # let vs = nn::VarStore::new(device); + /// # let config = LongT5Config::from_file(config_path); + /// # let longt5_model: LongT5ForConditionalGeneration = LongT5ForConditionalGeneration::new(&vs.root(), &config); + /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); + /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); + /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); + /// let encoder_attention_mask = + /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// let decoder_attention_mask = + /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); + /// + /// let model_output = no_grad(|| { + /// longt5_model.forward_t( + /// Some(&input_tensor), + /// Some(&encoder_attention_mask), + /// None, + /// Some(&target_tensor), + /// Some(&decoder_attention_mask), + /// None, + /// None, + /// None, + /// false, + /// ) + /// }); + /// ``` + fn forward_t( + &self, + input_ids: Option<&Tensor>, + cache: Cache, + attention_mask: Option<&Tensor>, + _token_type_ids: Option<&Tensor>, + _position_ids: Option<&Tensor>, + _input_embeds: Option<&Tensor>, + encoder_outputs: Option<&Tensor>, + decoder_input_ids: Option<&Tensor>, + train: bool, + ) -> Result { + let base_model_output = match cache { + Cache::LongT5Cache(cached_layer_states) => self.base_model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + None, + cached_layer_states, + train, + )?, + Cache::None => self.base_model.forward_t( + input_ids, + attention_mask, + encoder_outputs, + decoder_input_ids, + None, + None, + None, + None, + train, + )?, + _ => { + return Err(RustBertError::ValueError( + "Cache not compatible with LongT5 Model".into(), + )); + } + }; + + let lm_logits = if self.tie_word_embeddings { + base_model_output + .decoder_output + .linear::(&self.base_model.embeddings.ws, None) + * (self.model_dim.powf(-0.5)) + } else { + base_model_output + .decoder_output + .apply(self.lm_head.as_ref().unwrap()) + }; + + Ok(LMModelOutput { + lm_logits, + cache: Cache::LongT5Cache(base_model_output.next_cache), + }) + } +} + +/// Container holding a LongT5 model output. 
+pub type LongT5ModelOutput = T5ModelOutput; + +pub struct LongT5Generator { + model: LongT5ForConditionalGeneration, + tokenizer: TokenizerOption, + var_store: nn::VarStore, + generate_config: GenerateConfig, + bos_token_id: Option, + eos_token_ids: Option>, + pad_token_id: Option, + is_encoder_decoder: bool, + vocab_size: i64, + decoder_start_id: Option, + max_position_embeddings: i64, +} + +impl LongT5Generator { + pub fn new(generate_config: GenerateConfig) -> Result { + let vocab_path = generate_config.vocab_resource.get_local_path()?; + + let tokenizer = TokenizerOption::from_file( + ModelType::LongT5, + vocab_path.to_str().unwrap(), + None, + false, + None, + None, + )?; + + Self::new_with_tokenizer(generate_config, tokenizer) + } + + pub fn new_with_tokenizer( + generate_config: GenerateConfig, + tokenizer: TokenizerOption, + ) -> Result { + let config_path = generate_config.config_resource.get_local_path()?; + let weights_path = generate_config.model_resource.get_local_path()?; + let device = generate_config.device; + + generate_config.validate(); + let mut var_store = nn::VarStore::new(device); + + let config = LongT5Config::from_file(config_path); + let model = LongT5ForConditionalGeneration::new(var_store.root(), &config); + var_store.load(weights_path)?; + + let bos_token_id = config.bos_token_id; + let eos_token_ids = Some(match config.eos_token_id { + Some(value) => vec![value], + None => vec![1], + }); + let pad_token_id = Some(config.pad_token_id.unwrap_or(0)); + let vocab_size = config.vocab_size; + let is_encoder_decoder = true; + let decoder_start_id = pad_token_id; + // longT5 do not have an embedding matrix for position IDs and relies on relative positions instead + let max_position_embeddings = i64::MAX; + + Ok(LongT5Generator { + model, + tokenizer, + var_store, + generate_config, + bos_token_id, + eos_token_ids, + pad_token_id, + is_encoder_decoder, + vocab_size, + decoder_start_id, + max_position_embeddings, + }) + } +} + +impl PrivateLanguageGenerator + for LongT5Generator +{ + fn get_model(&self) -> &LongT5ForConditionalGeneration { + &self.model + } + fn _get_tokenizer(&self) -> &TokenizerOption { + &self.tokenizer + } + fn get_var_store(&self) -> &nn::VarStore { + &self.var_store + } + fn get_var_store_mut(&mut self) -> &mut nn::VarStore { + &mut self.var_store + } + fn get_config(&self) -> &GenerateConfig { + &self.generate_config + } + fn get_bos_id(&self) -> Option { + self.bos_token_id + } + fn get_eos_ids(&self) -> Option<&Vec> { + self.eos_token_ids.as_ref() + } + fn get_pad_id(&self) -> Option { + self.pad_token_id + } + fn is_encoder_decoder(&self) -> bool { + self.is_encoder_decoder + } + fn get_vocab_size(&self) -> i64 { + self.vocab_size + } + fn get_decoder_start_id(&self) -> Option { + self.decoder_start_id + } + fn get_max_positions_embeddings(&self) -> i64 { + self.max_position_embeddings + } + + fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option { + Some(self.get_model().encode(input_ids, attention_mask)) + } + + fn prepare_inputs_for_generation<'a>( + &self, + input_ids: Tensor, + encoder_outputs: Option<&'a Tensor>, + past: Cache, + attention_mask: Tensor, + ) -> PreparedInput<'a> { + match past { + Cache::LongT5Cache(past) => PreparedInput { + prepared_input: None, + prepared_attention_mask: Some(attention_mask), + prepared_encoder_output: encoder_outputs, + prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)), + prepared_position_ids: None, + prepared_past: Cache::LongT5Cache(past), + }, + Cache::None => 
PreparedInput { + prepared_input: None, + prepared_attention_mask: Some(attention_mask), + prepared_encoder_output: encoder_outputs, + prepared_decoder_input: Some(input_ids), + prepared_position_ids: None, + prepared_past: Cache::LongT5Cache(None), + }, + _ => panic!("Cache type incompatible with longT5"), + } + } + + fn encode_prompt_text( + &self, + prompt_text: &[S], + max_len: Option, + pad_token_id: Option, + ) -> Tensor + where + S: AsRef + Sync, + { + let tokens = self._get_tokenizer().encode_list( + prompt_text, + max_len + .map(|max_len| max_len as usize) + .unwrap_or(usize::MAX), + &TruncationStrategy::LongestFirst, + 0, + ); + let token_ids = tokens + .into_iter() + .map(|tokenized_input| tokenized_input.token_ids) + .collect::>>(); + + let max_len = token_ids.iter().map(|input| input.len()).max().unwrap(); + + let pad_token = match pad_token_id { + Some(value) => value, + None => self._get_tokenizer().get_unk_id(), + }; + + let token_ids = token_ids + .into_iter() + .map(|mut input| { + let temp = vec![pad_token; max_len - input.len()]; + input.extend(temp); + input + }) + .map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device())) + .collect::>(); + + Tensor::stack(&token_ids, 0) + } + + fn reorder_cache( + &self, + past: &mut Cache, + encoder_outputs: Option, + beam_indices: &Tensor, + ) -> Option { + match past { + Cache::LongT5Cache(old_cache_option) => match old_cache_option { + Some(old_cache) => { + for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() { + if self_layer_state.is_some() { + self_layer_state + .as_mut() + .unwrap() + .reorder_cache(beam_indices) + }; + if encoder_layer_state.is_some() { + encoder_layer_state + .as_mut() + .unwrap() + .reorder_cache(beam_indices) + }; + } + } + None => {} + }, + Cache::None => {} + _ => { + panic!("Invalid cache for LongT5 model"); + } + }; + encoder_outputs + } +} + +impl LanguageGenerator for LongT5Generator {} diff --git a/src/longt5/mod.rs b/src/longt5/mod.rs new file mode 100644 index 000000000..091040fdb --- /dev/null +++ b/src/longt5/mod.rs @@ -0,0 +1,59 @@ +//! # LongT5 (Efficient Text-To-Text Transformer for Long Sequences) +//! +//! Implementation of the LongT5 language model ([LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) Guo, Ainslie, Uthus, Ontanon, Ni, Sung, Yang, 2021). +//! The base model is implemented in the `longt5_model::LongT5Model` struct. This model includes a language model head: `longt5_model::LongT5ForConditionalGeneration` +//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! +//! # Model set-up and pre-trained weights loading +//! +//! All models expect the following resources: +//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) +//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. +//! - `T5Tokenizer` using a `spiece.model` sentence piece model +//! +//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. +//! +//! ```no_run +//! # fn main() -> anyhow::Result<()> { +//! # +//! use tch::{nn, Device}; +//! # use std::path::PathBuf; +//! 
use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration}; +//! use rust_bert::resources::{LocalResource, ResourceProvider}; +//! use rust_bert::Config; +//! use rust_tokenizers::tokenizer::T5Tokenizer; +//! +//! let config_resource = LocalResource { +//! local_path: PathBuf::from("path/to/config.json"), +//! }; +//! let sentence_piece_resource = LocalResource { +//! local_path: PathBuf::from("path/to/spiece.model"), +//! }; +//! let weights_resource = LocalResource { +//! local_path: PathBuf::from("path/to/model.ot"), +//! }; +//! let config_path = config_resource.get_local_path()?; +//! let spiece_path = sentence_piece_resource.get_local_path()?; +//! let weights_path = weights_resource.get_local_path()?; +//! +//! let device = Device::cuda_if_available(); +//! let mut vs = nn::VarStore::new(device); +//! let tokenizer = T5Tokenizer::from_file(spiece_path.to_str().unwrap(), true); +//! let config = LongT5Config::from_file(config_path); +//! let longt5_model = LongT5ForConditionalGeneration::new(&vs.root(), &config); +//! vs.load(weights_path)?; +//! +//! # Ok(()) +//! # } +//! ``` + +mod attention; +mod encoder; +mod layer_norm; +mod longt5_model; + +pub use attention::LayerState; +pub use longt5_model::{ + LongT5Config, LongT5ConfigResources, LongT5ForConditionalGeneration, LongT5Generator, + LongT5Model, LongT5ModelResources, LongT5VocabResources, +}; diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index 0b3ea0579..d8555dd2b 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -28,6 +28,7 @@ use crate::fnet::FNetConfig; use crate::gpt2::Gpt2Config; use crate::gpt_neo::GptNeoConfig; use crate::longformer::LongformerConfig; +use crate::longt5::LongT5Config; use crate::m2m_100::M2M100Config; use crate::marian::MarianConfig; use crate::mbart::MBartConfig; @@ -71,6 +72,8 @@ pub enum ModelType { MobileBert, #[serde(alias = "t5")] T5, + #[serde(alias = "longt5")] + LongT5, #[serde(alias = "albert")] Albert, XLNet, @@ -108,6 +111,8 @@ pub enum ConfigOption { OpenAiGpt(OpenAiGptConfig), /// T5 configuration T5(T5Config), + /// LongT5 configuration + LongT5(LongT5Config), /// Albert configuration Albert(AlbertConfig), /// XLNet configuration @@ -187,6 +192,7 @@ impl ConfigOption { ModelType::Marian => ConfigOption::Marian(MarianConfig::from_file(path)), ModelType::MobileBert => ConfigOption::MobileBert(MobileBertConfig::from_file(path)), ModelType::T5 => ConfigOption::T5(T5Config::from_file(path)), + ModelType::LongT5 => ConfigOption::LongT5(LongT5Config::from_file(path)), ModelType::Albert => ConfigOption::Albert(AlbertConfig::from_file(path)), ModelType::XLNet => ConfigOption::XLNet(XLNetConfig::from_file(path)), ModelType::GPT2 => ConfigOption::GPT2(Gpt2Config::from_file(path)), @@ -276,6 +282,7 @@ impl ConfigOption { .as_ref() .expect("No label dictionary (id2label) provided in configuration file"), Self::T5(_) => panic!("T5 does not use a label mapping"), + Self::LongT5(_) => panic!("LongT5 does not use a label mapping"), Self::OpenAiGpt(_) => panic!("OpenAI GPT does not use a label mapping"), Self::GPT2(_) => panic!("GPT2 does not use a label mapping"), Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"), @@ -294,6 +301,7 @@ impl ConfigOption { Self::Marian(config) => Some(config.max_position_embeddings), Self::MobileBert(config) => Some(config.max_position_embeddings), Self::T5(_) => None, + Self::LongT5(_) => None, Self::Albert(config) => Some(config.max_position_embeddings), Self::XLNet(_) => None, Self::GPT2(config) => 
Some(config.n_positions), @@ -473,7 +481,7 @@ impl TokenizerOption { lower_case, )?) } - ModelType::T5 => { + ModelType::T5 | ModelType::LongT5 => { if strip_accents.is_some() { return Err(RustBertError::InvalidConfigurationError(format!( "Optional input `strip_accents` set to value {} but cannot be used by {:?}", diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index 710365c57..f9787c471 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -219,6 +219,7 @@ pub enum Cache { GPT2Cache(Option>), BARTCache(Option, Option)>>), T5Cache(Option, Option)>>), + LongT5Cache(Option, Option)>>), XLNetCache(Option>>), ReformerCache(Option>>), ProphetNetCache(Option, Option)>>), diff --git a/src/pipelines/summarization.rs b/src/pipelines/summarization.rs index 4742e01af..ba45b6559 100644 --- a/src/pipelines/summarization.rs +++ b/src/pipelines/summarization.rs @@ -73,6 +73,7 @@ use crate::prophetnet::ProphetNetConditionalGenerator; use crate::resources::ResourceProvider; use crate::t5::T5Generator; +use crate::longt5::LongT5Generator; #[cfg(feature = "remote")] use crate::{ bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources}, @@ -219,6 +220,8 @@ pub enum SummarizationOption { Bart(BartGenerator), /// Summarizer based on T5 model T5(T5Generator), + /// Summarizer based on LongT5 model + LongT5(LongT5Generator), /// Summarizer based on ProphetNet model ProphetNet(ProphetNetConditionalGenerator), /// Summarizer based on Pegasus model @@ -232,6 +235,9 @@ impl SummarizationOption { config.into(), )?)), ModelType::T5 => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)), + ModelType::LongT5 => Ok(SummarizationOption::LongT5(LongT5Generator::new( + config.into(), + )?)), ModelType::ProphetNet => Ok(SummarizationOption::ProphetNet( ProphetNetConditionalGenerator::new(config.into())?, )), @@ -250,6 +256,7 @@ impl SummarizationOption { match *self { Self::Bart(_) => ModelType::Bart, Self::T5(_) => ModelType::T5, + Self::LongT5(_) => ModelType::LongT5, Self::ProphetNet(_) => ModelType::ProphetNet, Self::Pegasus(_) => ModelType::Pegasus, } @@ -271,6 +278,11 @@ impl SummarizationOption { .into_iter() .map(|output| output.text) .collect(), + Self::LongT5(ref model) => model + .generate(prompt_texts, None) + .into_iter() + .map(|output| output.text) + .collect(), Self::ProphetNet(ref model) => model .generate(prompt_texts, None) .into_iter() diff --git a/src/prophetnet/decoder.rs b/src/prophetnet/decoder.rs index 6e83c12c9..5b647ebea 100644 --- a/src/prophetnet/decoder.rs +++ b/src/prophetnet/decoder.rs @@ -12,7 +12,7 @@ use crate::common::dropout::Dropout; use crate::common::embeddings::process_ids_embeddings_pair; -use crate::common::kind::get_negative_infinity; +use crate::common::kind::get_min; use crate::prophetnet::attention::{ compute_all_stream_relative_buckets, LayerState, ProphetNetAttention, ProphetNetFeedForward, ProphetNetNgramAttention, @@ -26,7 +26,7 @@ use tch::{nn, Device, Kind, Tensor}; fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device, kind: Kind) -> Tensor { let left_block = Tensor::ones(&[ngram, sequence_length, sequence_length], (kind, device)) - * get_negative_infinity(kind).unwrap(); + * get_min(kind).unwrap(); let right_block = left_block.copy(); for stream_idx in 0..ngram { let _ = right_block.get(stream_idx).fill_diagonal_(0, false); @@ -515,7 +515,7 @@ impl ProphetNetDecoder { let causal_mask = Tensor::full( &[sequence_length, sequence_length], - 
get_negative_infinity(hidden_states.kind()).unwrap(), + get_min(hidden_states.kind()).unwrap(), (hidden_states.kind(), hidden_states.device()), ) .triu_(1); diff --git a/src/t5/attention.rs b/src/t5/attention.rs index b5be649ae..1a1dea8c0 100644 --- a/src/t5/attention.rs +++ b/src/t5/attention.rs @@ -43,6 +43,37 @@ impl LayerState { } } +pub fn get_relative_position_bucket( + relative_position: &Tensor, + bidirectional: bool, + num_buckets: i64, + max_distance: i64, +) -> Tensor { + let n = -relative_position; + let mut num_buckets = num_buckets; + let mut ret = n.zeros_like(); + let n = if bidirectional { + num_buckets /= 2; + ret += n.lt(0).to_kind(Kind::Int64) * num_buckets; + n.abs() + } else { + n.max_other(&n.zeros_like()) + }; + + let max_exact = num_buckets / 2; + let is_small = n.lt(max_exact); + + let value_if_large: Tensor = ((n.to_kind(Kind::Float) / max_exact as f64).log2() + / (max_distance as f64 / max_exact as f64).log2() + * (num_buckets - max_exact) as f64) + .to_kind(Kind::Int64) + + max_exact; + + let value_if_large = value_if_large.min_other(&value_if_large.full_like(num_buckets - 1)); + ret += n.where_self(&is_small, &value_if_large); + ret +} + #[derive(Debug)] pub struct T5Attention { is_decoder: bool, @@ -142,7 +173,7 @@ impl T5Attention { train: bool, ) -> (Tensor, Option, Option, Option) { let input_size = hidden_states.size(); - let (bs, seq_length, _) = (input_size[0], input_size[1], input_size[2]); + let (bs, seq_length) = (input_size[0], input_size[1]); let real_seq_length = if layer_state.is_some() { match query_length { @@ -245,44 +276,12 @@ impl T5Attention { (context, attention_weights, position_bias, layer_state) } - fn get_relative_position_bucket( - &self, - relative_position: &Tensor, - bidirectional: bool, - num_buckets: i64, - max_distance: i64, - ) -> Tensor { - let n = -relative_position; - let mut num_buckets = num_buckets; - let mut ret = n.zeros_like(); - let n = if bidirectional { - num_buckets /= 2; - ret += n.lt(0).to_kind(Kind::Int64) * num_buckets; - n.abs() - } else { - n.max_other(&n.zeros_like()) - }; - - let max_exact = num_buckets / 2; - let is_small = n.lt(max_exact); - - let value_if_large: Tensor = ((n.to_kind(Kind::Float) / max_exact as f64).log2() - / (max_distance as f64 / max_exact as f64).log2() - * (num_buckets - max_exact) as f64) - .to_kind(Kind::Int64) - + max_exact; - - let value_if_large = value_if_large.min_other(&value_if_large.full_like(num_buckets - 1)); - ret += n.where_self(&is_small, &value_if_large); - ret - } - fn compute_bias(&self, q_len: i64, k_len: i64, device: Device) -> Tensor { let context_position = Tensor::arange(q_len, (Kind::Int64, device)).unsqueeze(1); let memory_position = Tensor::arange(k_len, (Kind::Int64, device)).unsqueeze(0); let relative_position = memory_position - context_position; - let rp_bucket = self.get_relative_position_bucket( + let rp_bucket = get_relative_position_bucket( &relative_position, self.is_bidirectional, self.relative_attention_num_buckets, diff --git a/src/t5/encoder.rs b/src/t5/encoder.rs index 16903f6df..c1287b26c 100644 --- a/src/t5/encoder.rs +++ b/src/t5/encoder.rs @@ -17,20 +17,21 @@ use crate::t5::attention::{LayerState, T5LayerCrossAttention, T5LayerSelfAttenti use crate::t5::layer_norm::T5LayerNorm; use crate::t5::t5_model::FeedForwardProj; use crate::t5::T5Config; -use crate::Activation::gelu_new; +use crate::Activation::{gelu_new, relu}; use crate::RustBertError; use std::borrow::{Borrow, BorrowMut}; use tch::nn::LinearConfig; use tch::{nn, Kind, Scalar, 
Tensor}; -pub struct T5DenseReluDense { +pub struct T5DenseActDense { wi: nn::Linear, wo: nn::Linear, dropout: Dropout, + activation_function: TensorFunction, } -impl T5DenseReluDense { - pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseReluDense +impl T5DenseActDense { + pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseActDense where P: Borrow>, { @@ -42,29 +43,36 @@ impl T5DenseReluDense { let wi = nn::linear(p / "wi", config.d_model, config.d_ff, linear_config); let wo = nn::linear(p / "wo", config.d_ff, config.d_model, linear_config); let dropout = Dropout::new(config.dropout_rate); + let activation_function = match config.feed_forward_proj { + None | Some(FeedForwardProj::Relu) => relu.get_function(), + Some(FeedForwardProj::GatedGelu) => gelu_new.get_function(), + }; - T5DenseReluDense { wi, wo, dropout } + T5DenseActDense { + wi, + wo, + dropout, + activation_function, + } } pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor { - hidden_states - .apply(&self.wi) - .relu() + self.activation_function.get_fn()(&hidden_states.apply(&self.wi)) .apply_t(&self.dropout, train) .apply(&self.wo) } } -pub struct T5DenseGatedGeluDense { +pub struct T5DenseGatedActDense { wi_0: nn::Linear, wi_1: nn::Linear, wo: nn::Linear, dropout: Dropout, - activation: TensorFunction, + activation_function: TensorFunction, } -impl T5DenseGatedGeluDense { - pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseGatedGeluDense +impl T5DenseGatedActDense { + pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseGatedActDense where P: Borrow>, { @@ -77,19 +85,22 @@ impl T5DenseGatedGeluDense { let wi_1 = nn::linear(p / "wi_1", config.d_model, config.d_ff, linear_config); let wo = nn::linear(p / "wo", config.d_ff, config.d_model, linear_config); let dropout = Dropout::new(config.dropout_rate); - let activation = gelu_new.get_function(); + let activation_function = match config.feed_forward_proj { + None | Some(FeedForwardProj::Relu) => relu.get_function(), + Some(FeedForwardProj::GatedGelu) => gelu_new.get_function(), + }; - T5DenseGatedGeluDense { + T5DenseGatedActDense { wi_0, wi_1, wo, dropout, - activation, + activation_function, } } pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor { - let hidden_gelu = self.activation.get_fn()(&hidden_states.apply(&self.wi_0)); + let hidden_gelu = self.activation_function.get_fn()(&hidden_states.apply(&self.wi_0)); let hidden_linear = hidden_states.apply(&self.wi_1); (hidden_gelu * hidden_linear) .apply_t(&self.dropout, train) @@ -98,8 +109,8 @@ impl T5DenseGatedGeluDense { } pub enum T5FeedForwardLayer { - T5DenseReluDense(T5DenseReluDense), - T5DenseGatedGeluDense(T5DenseGatedGeluDense), + T5DenseActDense(T5DenseActDense), + T5DenseGatedActDense(T5DenseGatedActDense), } impl T5FeedForwardLayer { @@ -109,20 +120,18 @@ impl T5FeedForwardLayer { { match config.feed_forward_proj.unwrap_or(FeedForwardProj::Relu) { FeedForwardProj::Relu => { - T5FeedForwardLayer::T5DenseReluDense(T5DenseReluDense::new(p, config)) + T5FeedForwardLayer::T5DenseActDense(T5DenseActDense::new(p, config)) } FeedForwardProj::GatedGelu => { - T5FeedForwardLayer::T5DenseGatedGeluDense(T5DenseGatedGeluDense::new(p, config)) + T5FeedForwardLayer::T5DenseGatedActDense(T5DenseGatedActDense::new(p, config)) } } } pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor { match self { - T5FeedForwardLayer::T5DenseReluDense(ref model) => { - model.forward_t(hidden_states, train) - } - T5FeedForwardLayer::T5DenseGatedGeluDense(ref model) => { + 
T5FeedForwardLayer::T5DenseActDense(ref model) => model.forward_t(hidden_states, train), + T5FeedForwardLayer::T5DenseGatedActDense(ref model) => { model.forward_t(hidden_states, train) } } @@ -217,7 +226,7 @@ impl T5Block { } } - fn clamp_hidden_states(hidden_states: Tensor) -> Tensor { + pub(crate) fn clamp_hidden_states(hidden_states: Tensor) -> Tensor { if (hidden_states.kind() != Kind::Float) & bool::from(hidden_states.isinf().any()) { let clamp_value = match hidden_states.kind() { Kind::Half => half::f16::MAX.to_f64() - 1000., diff --git a/src/t5/mod.rs b/src/t5/mod.rs index 3c4790839..b13a670e7 100644 --- a/src/t5/mod.rs +++ b/src/t5/mod.rs @@ -54,6 +54,10 @@ mod layer_norm; mod t5_model; pub use attention::LayerState; +pub(crate) use attention::{get_relative_position_bucket, T5Attention, T5LayerCrossAttention}; +pub(crate) use encoder::{T5Block, T5BlockOutput, T5LayerFF, T5StackOutput}; +pub(crate) use layer_norm::T5LayerNorm; +pub(crate) use t5_model::{FeedForwardProj, TaskSpecificParams}; pub use t5_model::{ T5Config, T5ConfigResources, T5ForConditionalGeneration, T5ForSentenceEmbeddings, T5Generator, T5Model, T5ModelOutput, T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages, diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 3e4058188..190a925e8 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -147,7 +147,7 @@ pub struct T5Config { pub vocab_size: i64, pub feed_forward_proj: Option, pub tie_word_embeddings: Option, - task_specific_params: Option, + pub task_specific_params: Option, pub output_attentions: Option, pub output_hidden_states: Option, } @@ -250,7 +250,7 @@ impl T5Model { /// /// # Arguments /// - /// * `p` - Variable store path for the root of the BART model + /// * `p` - Variable store path for the root of the T5 model /// * `config` - `T5Config` object defining the model architecture /// /// # Example diff --git a/tests/deberta_v2.rs b/tests/deberta_v2.rs index c8ebd02ec..09ea45aee 100644 --- a/tests/deberta_v2.rs +++ b/tests/deberta_v2.rs @@ -17,7 +17,7 @@ fn deberta_v2_masked_lm() -> anyhow::Result<()> { DebertaV2ConfigResources::DEBERTA_V3_BASE, )); let config_path = config_resource.get_local_path()?; - let device = Device::cuda_if_available(); + let device = Device::Cpu; let vs = nn::VarStore::new(device); let mut config = DebertaV2Config::from_file(config_path); config.output_attentions = Some(true); diff --git a/tests/longt5.rs b/tests/longt5.rs new file mode 100644 index 000000000..8aa0fa897 --- /dev/null +++ b/tests/longt5.rs @@ -0,0 +1,64 @@ +use rust_bert::longt5::{LongT5ConfigResources, LongT5ModelResources, LongT5VocabResources}; +use rust_bert::pipelines::common::ModelType; +use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; +use rust_bert::resources::RemoteResource; + +#[test] +fn test_summarization_longt5() -> anyhow::Result<()> { + // Set-up translation model + let summarization_config = SummarizationConfig { + model_type: ModelType::LongT5, + model_resource: Box::new(RemoteResource::from_pretrained( + LongT5ModelResources::TGLOBAL_BASE_BOOK_SUMMARY, + )), + config_resource: Box::new(RemoteResource::from_pretrained( + LongT5ConfigResources::TGLOBAL_BASE_BOOK_SUMMARY, + )), + vocab_resource: Box::new(RemoteResource::from_pretrained( + LongT5VocabResources::TGLOBAL_BASE_BOOK_SUMMARY, + )), + merges_resource: None, + min_length: 30, + max_length: Some(200), + early_stopping: true, + num_beams: 2, + length_penalty: 2.0, + ..Default::default() + }; + let model = 
SummarizationModel::new(summarization_config)?; + + let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \ +from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \ +from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \ +a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \ +habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \ +used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \ +passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \ +weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \ +contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \ +and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \ +but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \ +\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \ +said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \ +said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors. \ +\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \ +a potentially habitable planet, but further observations will be required to say for sure. \" \ +K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \ +but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \ +on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \ +telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \ +about exoplanets like K2-18b."]; + + let output = model.summarize(&input); + + assert_eq! ( + output[0], + " The first discovery of water on an exoplanet, K2-18b, comes from two different sources: scientists \ +from the University of Montreal and a team from University College London. The scientists found that certain \ +wavelengths of light absorbed by water weakened when the planet was in the way of Earth, indicating that the \ +planet has an atmosphere. The Montreal team analyzed their own results using their own software, and confirmed \ +their conclusion. This is the first such discovery in a planet in its habitable zone - not too hot and not too cold for liquid water to exist." 
+ ); + + Ok(()) +} diff --git a/utils/convert_model.py b/utils/convert_model.py index dedc9f20a..3dc7f0836 100644 --- a/utils/convert_model.py +++ b/utils/convert_model.py @@ -10,7 +10,8 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert") - parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head") + parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings") + parser.add_argument("--skip_lm_head", action="store_true", help="Skip language model head") parser.add_argument("--prefix", help="Add a prefix on weight names") parser.add_argument("--suffix", action="store_true", help="Split weight names on '.' and keep only last part") args = parser.parse_args() @@ -24,7 +25,17 @@ for k, v in weights.items(): k = k.replace("gamma", "weight").replace("beta", "bias") if args.skip_embeddings: - if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}: + if k in { + "model.encoder.embed_tokens.weight", + "encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + "decoder.embed_tokens.weight" + }: + continue + if args.skip_lm_head: + if k in { + "lm_head.weight", + }: continue if args.prefix: k = args.prefix + k
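
For reference, the new `ModelType::LongT5` summarization support can also be pointed at locally converted weights, e.g. a checkpoint converted with `utils/convert_model.py` (passing `--skip_lm_head` when the language-model head is tied to the shared embeddings). The snippet below is a minimal sketch under that assumption; the file paths are placeholders and the pipeline wiring mirrors the test added above.

```no_run
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::LocalResource;
use std::path::PathBuf;

fn main() -> anyhow::Result<()> {
    // Point the summarization pipeline at locally converted LongT5 resources
    // (paths are placeholders for the output of the conversion utility).
    let summarization_config = SummarizationConfig {
        model_type: ModelType::LongT5,
        model_resource: Box::new(LocalResource {
            local_path: PathBuf::from("path/to/model.ot"),
        }),
        config_resource: Box::new(LocalResource {
            local_path: PathBuf::from("path/to/config.json"),
        }),
        vocab_resource: Box::new(LocalResource {
            local_path: PathBuf::from("path/to/spiece.model"),
        }),
        merges_resource: None,
        num_beams: 2,
        ..Default::default()
    };

    // Build the model and summarize a long input document.
    let model = SummarizationModel::new(summarization_config)?;
    let summaries = model.summarize(&["A long input document to be summarized ..."]);
    println!("{}", summaries[0]);
    Ok(())
}
```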