LongT5 implementation (#333)
* LongT5 config implementation

* LongT5 WIP: utility functions (1)

* LongT5 WIP: utility functions (2)

* LongT5 WIP: utility functions (3)

* LongT5 WIP: utility functions (4)

* Made T5 FF activations generic, exposed T5 modules to the crate

* LongT5 local attention WIP

* LongT5 local attention

* LongT5 global attention WIP

* LongT5 global attention

* LongT5 attention modules (WIP)

* Align LongT5 position bias with T5

* Addition of LongT5Block

* LongT5Stack WIP

* LongT5Stack implementation

* LongT5Model implementation

* LongT5ForConditionalGeneration implementation

* Addition of LongT5Generator, inclusion in pipelines

* LongT5 attention fixes

* Fix MIN/MAX dtype computation and mask for LongT5

* Updated min/max and infinity computation across models

* Transient global attention fixes

* Updated changelog, README, tests, clippy lints
guillaume-be committed Feb 12, 2023
1 parent 84561ec commit d7e9c03
Showing 23 changed files with 2,444 additions and 76 deletions.
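The commit log notes that a LongT5Generator was added and wired into the generation pipelines. As a rough orientation for downstream users, a minimal summarization sketch follows; the resource constants, the exact SummarizationConfig::new signature, and the anyhow error handling are assumptions for illustration, not taken from this diff:

use rust_bert::longt5::{LongT5ConfigResources, LongT5ModelResources, LongT5VocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;

fn main() -> anyhow::Result<()> {
    // Resource names are assumed for illustration; check the crate documentation
    // for the pretrained LongT5 weights actually shipped with this release.
    let config = SummarizationConfig::new(
        ModelType::LongT5,
        Box::new(RemoteResource::from_pretrained(
            LongT5ModelResources::TGLOBAL_BASE_BOOK_SUMMARY,
        )),
        Box::new(RemoteResource::from_pretrained(
            LongT5ConfigResources::TGLOBAL_BASE_BOOK_SUMMARY,
        )),
        Box::new(RemoteResource::from_pretrained(
            LongT5VocabResources::TGLOBAL_BASE_BOOK_SUMMARY,
        )),
        None, // SentencePiece-based model: no merges resource
    );
    let model = SummarizationModel::new(config)?;
    let output = model.summarize(&["A very long document to summarize..."]);
    println!("{output:?}");
    Ok(())
}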
1 change: 1 addition & 0 deletions .github/workflows/continuous-integration.yml
@@ -130,6 +130,7 @@ jobs:
command: test
args: --package rust-bert
--test sentence_embeddings
+--test longt5

convert-model:
name: Model conversion test
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -2,9 +2,15 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [Unreleased]
+## Added
+- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights.
+
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.

+## Fixed
+- MIN/MAX computation for float-like types (was set to infinity instead of the type's min/max value)
+
## [0.20.0] - 2023-01-21
## Added
- Addition of All-MiniLM-L6-V2 model weights
1 change: 1 addition & 0 deletions README.md
@@ -58,6 +58,7 @@ M2M100| | | |✅ | | | | |
Electra | |✅| | | | |✅| |
ALBERT |✅|✅|✅| | | |✅| ✅ |
T5 | | | |✅ |✅|✅| | ✅ |
+LongT5 | | | |✅ |✅| | | |
XLNet|✅|✅|✅|✅ | | |✅| |
Reformer|✅| |✅|✅ | | |✅| |
ProphetNet| | | |✅ |✅ | | | |
9 changes: 3 additions & 6 deletions src/bart/bart_model.rs
@@ -16,7 +16,7 @@ use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
-use crate::common::kind::get_negative_infinity;
+use crate::common::kind::get_min;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
    PreparedInput, PrivateLanguageGenerator,
@@ -273,7 +273,7 @@ pub(crate) fn _make_causal_mask(

    let mut mask = Tensor::full(
        &[target_length, target_length],
-        get_negative_infinity(dtype).unwrap(),
+        get_min(dtype).unwrap(),
        (dtype, device),
    );
    let mask_cond = Tensor::arange(target_length, (dtype, device));
@@ -311,10 +311,7 @@ pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option<i64>, dtype: Kind
        .expand(&[batch_size, 1, target_length, source_length], true)
        .totype(dtype);
    let inverted_mask: Tensor = 1 - expanded_mask;
-    inverted_mask.masked_fill(
-        &inverted_mask.to_kind(Kind::Bool),
-        get_negative_infinity(dtype).unwrap(),
-    )
+    inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), get_min(dtype).unwrap())
}

pub(crate) fn _prepare_decoder_attention_mask(
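For context on the _expand_mask change above, here is a small self-contained sketch (an illustration built directly on tch, not code from this commit) showing a [batch, source_length] padding mask expanded into an additive attention mask whose padded positions now hold the dtype minimum instead of negative infinity:

use tch::{Device, Kind, Tensor};

fn main() {
    // 1 = attend, 0 = padding, for one sequence of length 4.
    let mask = Tensor::of_slice(&[1i64, 1, 1, 0]).view([1, 4]);
    // [batch, source_length] -> [batch, 1, target_length, source_length]
    let expanded_mask = mask
        .unsqueeze(1)
        .unsqueeze(1)
        .expand(&[1, 1, 4, 4], true)
        .totype(Kind::Float);
    let inverted_mask: Tensor = 1 - expanded_mask;
    // After the fix, padded positions carry f32::MIN (finite) rather than -inf,
    // so a downstream softmax cannot turn fully masked rows into NaN.
    let additive_mask =
        inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), f32::MIN as f64);
    additive_mask.print();
}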
19 changes: 19 additions & 0 deletions src/common/kind.rs
@@ -38,3 +38,22 @@ pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError>
        }
    })
}
+
+pub(crate) fn get_min(kind: Kind) -> Result<Scalar, RustBertError> {
+    Ok(match kind {
+        Kind::Uint8 => Scalar::int(u8::MIN.into()),
+        Kind::Int8 => Scalar::int(i8::MIN.into()),
+        Kind::Int16 => Scalar::int(i16::MIN.into()),
+        Kind::Int => Scalar::int(i32::MIN.into()),
+        Kind::Int64 => Scalar::int(i64::MIN),
+        Kind::Half => Scalar::float(half::f16::MIN.into()),
+        Kind::Float => Scalar::float(f32::MIN.into()),
+        Kind::BFloat16 => Scalar::float(half::bf16::MIN.into()),
+        Kind::Double => Scalar::float(f64::MIN),
+        _ => {
+            return Err(RustBertError::ValueError(format!(
+                "Type not supported: attempted to get min for {kind:?}",
+            )))
+        }
+    })
+}
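To see why the MIN/MAX fix recorded in the changelog matters, consider this standalone sketch (not part of the diff): a fully masked row filled with negative infinity softmaxes to NaN, while a row filled with the dtype minimum stays finite:

use tch::{Device, Kind, Tensor};

fn main() {
    let opts = (Kind::Float, Device::Cpu);
    // A row in which every position has been masked out.
    let inf_filled = Tensor::full(&[1, 4], f64::NEG_INFINITY, opts);
    let min_filled = Tensor::full(&[1, 4], f32::MIN as f64, opts);

    // -inf - (-inf) = NaN during the internal max-shift, so the row is all NaN.
    inf_filled.softmax(-1, Kind::Float).print(); // NaN NaN NaN NaN
    // Finite MIN values shift to zero and softmax to a uniform distribution.
    min_filled.softmax(-1, Kind::Float).print(); // 0.25 0.25 0.25 0.25
}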
4 changes: 2 additions & 2 deletions src/deberta/deberta_model.rs
@@ -16,7 +16,7 @@ use crate::bert::{
use crate::common::activations::TensorFunction;
use crate::common::dropout::{Dropout, XDropout};
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
-use crate::common::kind::get_negative_infinity;
+use crate::common::kind::get_min;
use crate::deberta::embeddings::DebertaEmbeddings;
use crate::deberta::encoder::{DebertaEncoder, DebertaEncoderOutput};
use crate::{Activation, Config, RustBertError};
@@ -264,7 +264,7 @@ impl Config for DebertaConfig {}
pub fn x_softmax(input: &Tensor, mask: &Tensor, dim: i64) -> Tensor {
    let inverse_mask = (1 - mask).to_kind(Kind::Bool);
    input
-        .masked_fill(&inverse_mask, get_negative_infinity(input.kind()).unwrap())
+        .masked_fill(&inverse_mask, get_min(input.kind()).unwrap())
        .softmax(dim, input.kind())
        .masked_fill(&inverse_mask, 0.0)
}
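A hypothetical call for reference, assuming tch's Device, Kind and Tensor are in scope and shapes chosen purely for illustration: the mask broadcasts over the score tensor, and positions where it is 0 end up with exactly zero attention weight:

fn demo_x_softmax() {
    let scores = Tensor::rand(&[1, 2, 4, 4], (Kind::Float, Device::Cpu));
    // Broadcastable 0/1 mask: the last source position is padding.
    let mask = Tensor::of_slice(&[1i64, 1, 1, 0]).view([1, 1, 1, 4]);
    // Scores are filled with the dtype minimum where masked, softmaxed, then
    // zeroed at the masked positions, so each row sums to 1 over the first
    // three columns and is exactly 0.0 in the last.
    let probs = x_softmax(&scores, &mask, -1);
    probs.print();
}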
4 changes: 4 additions & 0 deletions src/lib.rs
@@ -68,6 +68,7 @@
//!Electra | |✅| | | | |✅| |
//!ALBERT |✅|✅|✅| | | |✅| ✅ |
//!T5 | | | |✅ |✅|✅| | ✅ |
+//!LongT5 | | | |✅ |✅| | | |
//!XLNet|✅|✅|✅|✅ | | |✅| |
//!Reformer|✅| |✅|✅ | | |✅| |
//!ProphetNet| | | |✅ |✅ | | | |
@@ -695,6 +696,8 @@
// These are used abundantly in this code
#![allow(clippy::assign_op_pattern, clippy::upper_case_acronyms)]

+extern crate core;
+
pub mod albert;
pub mod bart;
pub mod bert;
@@ -707,6 +710,7 @@ pub mod fnet;
pub mod gpt2;
pub mod gpt_neo;
pub mod longformer;
+pub mod longt5;
pub mod m2m_100;
pub mod marian;
pub mod mbart;
