From 9be555fc1e980377c75441ac293996592e0cd7be Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Mon, 13 Sep 2021 16:51:30 +0200 Subject: [PATCH] Addition of primitive infinite utilities, rework BERT, BART --- Cargo.toml | 1 + benches/tensor_operations_benchmark.rs | 6 ++-- src/bart/attention.rs | 3 +- src/bart/bart_model.rs | 15 +++++---- src/bart/embeddings.rs | 7 ++--- src/bert/attention.rs | 5 +-- src/bert/bert_model.rs | 40 +++++++++++------------- src/bert/encoder.rs | 19 +++++------- src/common/kind.rs | 43 ++++++++++++++++++++++++++ src/common/mod.rs | 1 + 10 files changed, 90 insertions(+), 50 deletions(-) create mode 100644 src/common/kind.rs diff --git a/Cargo.toml b/Cargo.toml index 4bb4382ef..73731d598 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ cached-path = "0.5.1" lazy_static = "1.4.0" uuid = { version = "0.8.2", features = ["v4"] } thiserror = "1.0.26" +half = "1.7.1" [dev-dependencies] anyhow = "1.0.43" diff --git a/benches/tensor_operations_benchmark.rs b/benches/tensor_operations_benchmark.rs index fba02f540..2ef3d1407 100644 --- a/benches/tensor_operations_benchmark.rs +++ b/benches/tensor_operations_benchmark.rs @@ -3,7 +3,7 @@ extern crate criterion; use criterion::{black_box, Criterion}; use std::time::{Duration, Instant}; -use tch::kind::Kind::Float; +use tch::kind::Kind; use tch::{Device, Tensor}; fn matrix_multiply(iters: u64, input: &Tensor, weights: &Tensor) -> Duration { @@ -21,8 +21,8 @@ fn bench_tensor_ops(c: &mut Criterion) { unsafe { torch_sys::dummy_cuda_dependency(); } - let input = Tensor::rand(&[32, 128, 512], (Float, Device::cuda_if_available())); - let weights = Tensor::rand(&[512, 512], (Float, Device::cuda_if_available())); + let input = Tensor::rand(&[32, 128, 512], (Kind::Float, Device::cuda_if_available())); + let weights = Tensor::rand(&[512, 512], (Kind::Float, Device::cuda_if_available())); let _ = &input.matmul(&weights); c.bench_function("Matrix multiply ", |b| { diff --git a/src/bart/attention.rs b/src/bart/attention.rs index e0782564b..4b19d9636 100644 --- a/src/bart/attention.rs +++ b/src/bart/attention.rs @@ -13,7 +13,6 @@ use crate::common::dropout::Dropout; use std::borrow::Borrow; -use tch::kind::Kind::Float; use tch::{nn, Tensor}; #[derive(Debug)] @@ -164,7 +163,7 @@ impl BartAttention { attention_weights.view([bs * self.num_heads, target_length, source_length]); }; - attention_weights = attention_weights.softmax(-1, Float); + attention_weights = attention_weights.softmax(-1, attention_weights.kind()); let saved_attention_weights = if self.output_attentions { Some(attention_weights.view((bs, self.num_heads, target_length, source_length))) diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index acb83e279..734853632 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -16,6 +16,7 @@ use crate::bart::decoder::BartDecoder; use crate::bart::encoder::BartEncoder; use crate::common::activations::Activation; use crate::common::dropout::Dropout; +use crate::common::kind::get_negative_infinity; use crate::common::resources::{RemoteResource, Resource}; use crate::gpt2::{ Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources, @@ -33,7 +34,6 @@ use rust_tokenizers::vocab::{RobertaVocab, Vocab}; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::collections::HashMap; -use tch::kind::Kind::Int64; use tch::nn::{embedding, EmbeddingConfig}; use tch::{nn, Device, Kind, Tensor}; @@ -271,9 +271,12 @@ pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option) -> Tensor .unsqueeze(1) .unsqueeze(1) .expand(&[batch_size, 1, target_length, source_length], true) - .totype(Kind::Float); + .totype(mask.kind()); let inverted_mask: Tensor = 1 - expanded_mask; - inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), f64::NEG_INFINITY) + inverted_mask.masked_fill( + &inverted_mask.to_kind(Kind::Bool), + get_negative_infinity(inverted_mask.kind()).unwrap(), + ) } pub(crate) fn _prepare_decoder_attention_mask( @@ -308,9 +311,9 @@ pub(crate) fn _prepare_decoder_attention_mask( fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor { let index_eos: Tensor = input_ids .ne(pad_token_id) - .sum_dim_intlist(&[-1], true, Int64) + .sum_dim_intlist(&[-1], true, Kind::Int64) - 1; - let output = input_ids.empty_like().to_kind(Int64); + let output = input_ids.empty_like().to_kind(Kind::Int64); output .select(1, 0) .copy_(&input_ids.gather(1, &index_eos, true).squeeze()); @@ -812,7 +815,7 @@ impl BartForSequenceClassification { train, ); let eos_mask = input_ids.eq(self.eos_token_id); - let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64); + let reshape = eos_mask.sum_dim_intlist(&[1], true, Kind::Bool); let sentence_representation = base_model_output .decoder_output .permute(&[2, 0, 1]) diff --git a/src/bart/embeddings.rs b/src/bart/embeddings.rs index faa149eb5..8ae372b26 100644 --- a/src/bart/embeddings.rs +++ b/src/bart/embeddings.rs @@ -12,9 +12,8 @@ // limitations under the License. use std::borrow::Borrow; -use tch::kind::Kind::Int64; use tch::nn::embedding; -use tch::{nn, Tensor}; +use tch::{nn, Kind, Tensor}; /// # Abstraction that holds a embeddings configuration pub enum EmbeddingOption { @@ -67,7 +66,7 @@ impl LearnedPositionalEmbedding { let positions = Tensor::arange_start( past_key_values_length, past_key_values_length + sequence_length, - (Int64, input.device()), + (Kind::Int64, input.device()), ) + self.offset; positions.apply(&self.embedding) } @@ -102,7 +101,7 @@ impl SinusoidalPositionalEmbedding { let positions = Tensor::arange_start( past_key_values_length, past_key_values_length + sequence_length, - (Int64, input.device()), + (Kind::Int64, input.device()), ); positions.apply(&self.embedding) } diff --git a/src/bert/attention.rs b/src/bert/attention.rs index 9a2af5b60..0f45330e7 100644 --- a/src/bert/attention.rs +++ b/src/bert/attention.rs @@ -15,7 +15,6 @@ use crate::bert::bert_model::BertConfig; use crate::common::activations::TensorFunction; use crate::common::dropout::Dropout; use std::borrow::Borrow; -use tch::kind::Kind::Float; use tch::{nn, Tensor}; #[derive(Debug)] @@ -124,7 +123,9 @@ impl BertSelfAttention { query_layer.matmul(&key_layer.transpose(-1, -2)) }; - let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train); + let weights = scores + .softmax(-1, scores.kind()) + .apply_t(&self.dropout, train); let context = self.flatten(weights.matmul(&value_layer), bs, self.attention_head_size); if !self.output_attentions { diff --git a/src/bert/bert_model.rs b/src/bert/bert_model.rs index 6aa9436e0..a2c512251 100644 --- a/src/bert/bert_model.rs +++ b/src/bert/bert_model.rs @@ -24,7 +24,6 @@ use crate::{Config, RustBertError}; use serde::{Deserialize, Serialize}; use std::borrow::Borrow; use std::collections::HashMap; -use tch::kind::Kind::Float; use tch::nn::Init; use tch::{nn, Kind, Tensor}; @@ -251,20 +250,19 @@ impl BertModel { /// /// ```no_run /// # use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings}; - /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use tch::{nn, Device, Tensor, no_grad, Kind}; /// # use rust_bert::Config; /// # use std::path::Path; - /// # use tch::kind::Kind::Int64; /// # let config_path = Path::new("path/to/config.json"); /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = BertConfig::from_file(config_path); /// # let bert_model: BertModel = BertModel::new(&vs.root(), &config); /// let (batch_size, sequence_length) = (64, 128); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); - /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device)) /// .expand(&[batch_size, sequence_length], true); /// /// let model_output = no_grad(|| { @@ -296,14 +294,14 @@ impl BertModel { let (input_shape, device) = get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?; - let calc_mask = Tensor::ones(&input_shape, (Kind::Int64, device)); + let calc_mask = Tensor::ones(&input_shape, (Kind::Int8, device)); let mask = mask.unwrap_or(&calc_mask); let extended_attention_mask = match mask.dim() { 3 => mask.unsqueeze(1), 2 => { if self.is_decoder { - let seq_ids = Tensor::arange(input_shape[1], (Float, device)); + let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device)); let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[ input_shape[0], input_shape[1], @@ -336,7 +334,7 @@ impl BertModel { encoder_hidden_states_shape[0], encoder_hidden_states_shape[1], ], - (Kind::Int64, device), + (Kind::Int8, device), ), }; match encoder_mask.dim() { @@ -522,20 +520,19 @@ impl BertForMaskedLM { /// /// ```no_run /// # use rust_bert::bert::{BertForMaskedLM, BertConfig}; - /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use tch::{nn, Device, Tensor, no_grad, Kind}; /// # use rust_bert::Config; /// # use std::path::Path; - /// # use tch::kind::Kind::Int64; /// # let config_path = Path::new("path/to/config.json"); /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = BertConfig::from_file(config_path); /// # let bert_model = BertForMaskedLM::new(&vs.root(), &config); /// let (batch_size, sequence_length) = (64, 128); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); - /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device)) /// .expand(&[batch_size, sequence_length], true); /// /// let model_output = no_grad(|| { @@ -667,20 +664,19 @@ impl BertForSequenceClassification { /// /// ```no_run /// # use rust_bert::bert::{BertForSequenceClassification, BertConfig}; - /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use tch::{nn, Device, Tensor, no_grad, Kind}; /// # use rust_bert::Config; /// # use std::path::Path; - /// # use tch::kind::Kind::Int64; /// # let config_path = Path::new("path/to/config.json"); /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = BertConfig::from_file(config_path); /// # let bert_model = BertForSequenceClassification::new(&vs.root(), &config); /// let (batch_size, sequence_length) = (64, 128); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); - /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); - /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); + /// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device)) /// .expand(&[batch_size, sequence_length], true); /// /// let model_output = no_grad(|| { diff --git a/src/bert/encoder.rs b/src/bert/encoder.rs index ea28d5760..b88c13f45 100644 --- a/src/bert/encoder.rs +++ b/src/bert/encoder.rs @@ -108,18 +108,17 @@ impl BertLayer { /// /// ```no_run /// # use rust_bert::bert::{BertConfig, BertLayer}; - /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use tch::{nn, Device, Tensor, no_grad, Kind}; /// # use rust_bert::Config; /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Float}; /// # let config_path = Path::new("path/to/config.json"); /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = BertConfig::from_file(config_path); /// let layer: BertLayer = BertLayer::new(&vs.root(), &config); /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Float, device)); - /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device)); + /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); /// /// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false)); /// ``` @@ -234,18 +233,17 @@ impl BertEncoder { /// /// ```no_run /// # use rust_bert::bert::{BertConfig, BertEncoder}; - /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use tch::{nn, Device, Tensor, no_grad, Kind}; /// # use rust_bert::Config; /// # use std::path::Path; - /// # use tch::kind::Kind::{Int64, Float}; /// # let config_path = Path::new("path/to/config.json"); /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = BertConfig::from_file(config_path); /// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config); /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Float, device)); - /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device)); + /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int8, device)); /// /// let encoder_output = /// no_grad(|| encoder.forward_t(&input_tensor, Some(&mask), None, None, false)); @@ -361,17 +359,16 @@ impl BertPooler { /// /// ```no_run /// # use rust_bert::bert::{BertConfig, BertPooler}; - /// # use tch::{nn, Device, Tensor, no_grad}; + /// # use tch::{nn, Device, Tensor, no_grad, Kind}; /// # use rust_bert::Config; /// # use std::path::Path; - /// # use tch::kind::Kind::Float; /// # let config_path = Path::new("path/to/config.json"); /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = BertConfig::from_file(config_path); /// let pooler: BertPooler = BertPooler::new(&vs.root(), &config); /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Float, device)); + /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device)); /// /// let pooler_output = no_grad(|| pooler.forward(&input_tensor)); /// ``` diff --git a/src/common/kind.rs b/src/common/kind.rs new file mode 100644 index 000000000..1893735c8 --- /dev/null +++ b/src/common/kind.rs @@ -0,0 +1,43 @@ +use crate::RustBertError; +use half; +use tch::{Kind, Scalar}; + +pub(crate) fn get_positive_infinity(kind: Kind) -> Result { + Ok(match kind { + Kind::Uint8 => Scalar::int(u8::MAX.into()), + Kind::Int8 => Scalar::int(i8::MAX.into()), + Kind::Int16 => Scalar::int(i16::MAX.into()), + Kind::Int => Scalar::int(i32::MAX.into()), + Kind::Int64 => Scalar::int(i64::MAX), + Kind::Half => Scalar::float(half::f16::MAX.into()), + Kind::Float => Scalar::float(f32::MAX.into()), + Kind::BFloat16 => Scalar::float(half::bf16::MAX.into()), + Kind::Double => Scalar::float(f64::MAX), + _ => { + return Err(RustBertError::ValueError(format!( + "Type not supported: attempted to get positive infinity for {:?}", + kind + ))) + } + }) +} + +pub(crate) fn get_negative_infinity(kind: Kind) -> Result { + Ok(match kind { + Kind::Uint8 => Scalar::int(u8::MIN.into()), + Kind::Int8 => Scalar::int(i8::MIN.into()), + Kind::Int16 => Scalar::int(i16::MIN.into()), + Kind::Int => Scalar::int(i32::MIN.into()), + Kind::Int64 => Scalar::int(i64::MIN), + Kind::Half => Scalar::float(half::f16::MIN.into()), + Kind::Float => Scalar::float(f32::MIN.into()), + Kind::BFloat16 => Scalar::float(half::bf16::MIN.into()), + Kind::Double => Scalar::float(f64::MIN), + _ => { + return Err(RustBertError::ValueError(format!( + "Type not supported: attempted to get negative infinity for {:?}", + kind + ))) + } + }) +} diff --git a/src/common/mod.rs b/src/common/mod.rs index e8bebb48c..122f648e7 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,6 +3,7 @@ pub mod config; pub(crate) mod dropout; pub(crate) mod embeddings; pub mod error; +pub(crate) mod kind; pub(crate) mod linear; pub mod resources; pub(crate) mod summary;