Skip to content

Commit

Permalink
Addition of primitive infinite utilities, rework BERT, BART
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 13, 2021
1 parent 13995b3 commit 9be555f
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 50 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ cached-path = "0.5.1"
lazy_static = "1.4.0"
uuid = { version = "0.8.2", features = ["v4"] }
thiserror = "1.0.26"
half = "1.7.1"

[dev-dependencies]
anyhow = "1.0.43"
Expand Down
6 changes: 3 additions & 3 deletions benches/tensor_operations_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ extern crate criterion;

use criterion::{black_box, Criterion};
use std::time::{Duration, Instant};
use tch::kind::Kind::Float;
use tch::kind::Kind;
use tch::{Device, Tensor};

fn matrix_multiply(iters: u64, input: &Tensor, weights: &Tensor) -> Duration {
Expand All @@ -21,8 +21,8 @@ fn bench_tensor_ops(c: &mut Criterion) {
unsafe {
torch_sys::dummy_cuda_dependency();
}
let input = Tensor::rand(&[32, 128, 512], (Float, Device::cuda_if_available()));
let weights = Tensor::rand(&[512, 512], (Float, Device::cuda_if_available()));
let input = Tensor::rand(&[32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand(&[512, 512], (Kind::Float, Device::cuda_if_available()));

let _ = &input.matmul(&weights);
c.bench_function("Matrix multiply ", |b| {
Expand Down
3 changes: 1 addition & 2 deletions src/bart/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};

#[derive(Debug)]
Expand Down Expand Up @@ -164,7 +163,7 @@ impl BartAttention {
attention_weights.view([bs * self.num_heads, target_length, source_length]);
};

attention_weights = attention_weights.softmax(-1, Float);
attention_weights = attention_weights.softmax(-1, attention_weights.kind());

let saved_attention_weights = if self.output_attentions {
Some(attention_weights.view((bs, self.num_heads, target_length, source_length)))
Expand Down
15 changes: 9 additions & 6 deletions src/bart/bart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::kind::get_negative_infinity;
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
Expand All @@ -33,7 +34,6 @@ use rust_tokenizers::vocab::{RobertaVocab, Vocab};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Device, Kind, Tensor};

Expand Down Expand Up @@ -271,9 +271,12 @@ pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option<i64>) -> Tensor
.unsqueeze(1)
.unsqueeze(1)
.expand(&[batch_size, 1, target_length, source_length], true)
.totype(Kind::Float);
.totype(mask.kind());
let inverted_mask: Tensor = 1 - expanded_mask;
inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), f64::NEG_INFINITY)
inverted_mask.masked_fill(
&inverted_mask.to_kind(Kind::Bool),
get_negative_infinity(inverted_mask.kind()).unwrap(),
)
}

pub(crate) fn _prepare_decoder_attention_mask(
Expand Down Expand Up @@ -308,9 +311,9 @@ pub(crate) fn _prepare_decoder_attention_mask(
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let index_eos: Tensor = input_ids
.ne(pad_token_id)
.sum_dim_intlist(&[-1], true, Int64)
.sum_dim_intlist(&[-1], true, Kind::Int64)
- 1;
let output = input_ids.empty_like().to_kind(Int64);
let output = input_ids.empty_like().to_kind(Kind::Int64);
output
.select(1, 0)
.copy_(&input_ids.gather(1, &index_eos, true).squeeze());
Expand Down Expand Up @@ -812,7 +815,7 @@ impl BartForSequenceClassification {
train,
);
let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum_dim_intlist(&[1], true, Int64);
let reshape = eos_mask.sum_dim_intlist(&[1], true, Kind::Int64);
let sentence_representation = base_model_output
.decoder_output
.permute(&[2, 0, 1])
Expand Down
7 changes: 3 additions & 4 deletions src/bart/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
// limitations under the License.

use std::borrow::Borrow;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
use tch::{nn, Kind, Tensor};

/// # Abstraction that holds an embeddings configuration
pub enum EmbeddingOption {
Expand Down Expand Up @@ -67,7 +66,7 @@ impl LearnedPositionalEmbedding {
let positions = Tensor::arange_start(
past_key_values_length,
past_key_values_length + sequence_length,
(Int64, input.device()),
(Kind::Int64, input.device()),
) + self.offset;
positions.apply(&self.embedding)
}
Expand Down Expand Up @@ -102,7 +101,7 @@ impl SinusoidalPositionalEmbedding {
let positions = Tensor::arange_start(
past_key_values_length,
past_key_values_length + sequence_length,
(Int64, input.device()),
(Kind::Int64, input.device()),
);
positions.apply(&self.embedding)
}
Expand Down
5 changes: 3 additions & 2 deletions src/bert/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use crate::bert::bert_model::BertConfig;
use crate::common::activations::TensorFunction;
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};

#[derive(Debug)]
Expand Down Expand Up @@ -124,7 +123,9 @@ impl BertSelfAttention {
query_layer.matmul(&key_layer.transpose(-1, -2))
};

let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let weights = scores
.softmax(-1, scores.kind())
.apply_t(&self.dropout, train);
let context = self.flatten(weights.matmul(&value_layer), bs, self.attention_head_size);

if !self.output_attentions {
Expand Down
40 changes: 18 additions & 22 deletions src/bert/bert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
use tch::kind::Kind::Float;
use tch::nn::Init;
use tch::{nn, Kind, Tensor};

Expand Down Expand Up @@ -251,20 +250,19 @@ impl<T: BertEmbedding> BertModel<T> {
///
/// ```no_run
/// # use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
Expand Down Expand Up @@ -296,14 +294,14 @@ impl<T: BertEmbedding> BertModel<T> {
let (input_shape, device) =
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

let calc_mask = Tensor::ones(&input_shape, (Kind::Int64, device));
let calc_mask = Tensor::ones(&input_shape, (Kind::Int8, device));
let mask = mask.unwrap_or(&calc_mask);

let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Float, device));
let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
input_shape[0],
input_shape[1],
Expand Down Expand Up @@ -336,7 +334,7 @@ impl<T: BertEmbedding> BertModel<T> {
encoder_hidden_states_shape[0],
encoder_hidden_states_shape[1],
],
(Kind::Int64, device),
(Kind::Int8, device),
),
};
match encoder_mask.dim() {
Expand Down Expand Up @@ -522,20 +520,19 @@ impl BertForMaskedLM {
///
/// ```no_run
/// # use rust_bert::bert::{BertForMaskedLM, BertConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model = BertForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
Expand Down Expand Up @@ -667,20 +664,19 @@ impl BertForSequenceClassification {
///
/// ```no_run
/// # use rust_bert::bert::{BertForSequenceClassification, BertConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Kind::Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Kind::Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let model_output = no_grad(|| {
Expand Down
19 changes: 8 additions & 11 deletions src/bert/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,17 @@ impl BertLayer {
///
/// ```no_run
/// # use rust_bert::bert::{BertConfig, BertLayer};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Float};
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// let layer: BertLayer = BertLayer::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Float, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
///
/// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false));
/// ```
Expand Down Expand Up @@ -234,18 +233,17 @@ impl BertEncoder {
///
/// ```no_run
/// # use rust_bert::bert::{BertConfig, BertEncoder};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Float};
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Float, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int8, device));
///
/// let encoder_output =
/// no_grad(|| encoder.forward_t(&input_tensor, Some(&mask), None, None, false));
Expand Down Expand Up @@ -361,17 +359,16 @@ impl BertPooler {
///
/// ```no_run
/// # use rust_bert::bert::{BertConfig, BertPooler};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Float;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// let pooler: BertPooler = BertPooler::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Float, device));
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
///
/// let pooler_output = no_grad(|| pooler.forward(&input_tensor));
/// ```
Expand Down
43 changes: 43 additions & 0 deletions src/common/kind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use crate::RustBertError;
use half;
use tch::{Kind, Scalar};

/// Returns the scalar used as "positive infinity" for tensors of the given `kind`.
///
/// The value is the `MAX` constant of the Rust (or `half`) type matching `kind`,
/// wrapped as an integer `Scalar` for integer kinds and as a floating-point
/// `Scalar` for `Half`, `Float`, `BFloat16` and `Double`.
///
/// # Errors
///
/// Returns `RustBertError::ValueError` for any `kind` not listed in the match below.
pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
    let upper_bound = match kind {
        // Integer kinds: widen the type's MAX to i64.
        Kind::Uint8 => Scalar::int(i64::from(u8::MAX)),
        Kind::Int8 => Scalar::int(i64::from(i8::MAX)),
        Kind::Int16 => Scalar::int(i64::from(i16::MAX)),
        Kind::Int => Scalar::int(i64::from(i32::MAX)),
        Kind::Int64 => Scalar::int(i64::MAX),
        // Floating-point kinds: widen the type's MAX to f64.
        Kind::Half => Scalar::float(f64::from(half::f16::MAX)),
        Kind::Float => Scalar::float(f64::from(f32::MAX)),
        Kind::BFloat16 => Scalar::float(f64::from(half::bf16::MAX)),
        Kind::Double => Scalar::float(f64::MAX),
        _ => {
            let message = format!(
                "Type not supported: attempted to get positive infinity for {:?}",
                kind
            );
            return Err(RustBertError::ValueError(message));
        }
    };
    Ok(upper_bound)
}

/// Returns the scalar used as "negative infinity" for tensors of the given `kind`.
///
/// The value is the `MIN` constant of the Rust (or `half`) type matching `kind`,
/// wrapped as an integer `Scalar` for integer kinds and as a floating-point
/// `Scalar` for `Half`, `Float`, `BFloat16` and `Double`. For `Uint8` this is
/// `u8::MIN`, i.e. zero.
///
/// # Errors
///
/// Returns `RustBertError::ValueError` for any `kind` not listed in the match below.
pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
    let lower_bound = match kind {
        // Integer kinds: widen the type's MIN to i64.
        Kind::Uint8 => Scalar::int(i64::from(u8::MIN)),
        Kind::Int8 => Scalar::int(i64::from(i8::MIN)),
        Kind::Int16 => Scalar::int(i64::from(i16::MIN)),
        Kind::Int => Scalar::int(i64::from(i32::MIN)),
        Kind::Int64 => Scalar::int(i64::MIN),
        // Floating-point kinds: widen the type's MIN to f64.
        Kind::Half => Scalar::float(f64::from(half::f16::MIN)),
        Kind::Float => Scalar::float(f64::from(f32::MIN)),
        Kind::BFloat16 => Scalar::float(f64::from(half::bf16::MIN)),
        Kind::Double => Scalar::float(f64::MIN),
        _ => {
            let message = format!(
                "Type not supported: attempted to get negative infinity for {:?}",
                kind
            );
            return Err(RustBertError::ValueError(message));
        }
    };
    Ok(lower_bound)
}
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod config;
pub(crate) mod dropout;
pub(crate) mod embeddings;
pub mod error;
pub(crate) mod kind;
pub(crate) mod linear;
pub mod resources;
pub(crate) mod summary;
Expand Down

0 comments on commit 9be555f

Please sign in to comment.