Skip to content

Commit

Permalink
Updated GPT-Neo, working half precision greedy generation
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 26, 2021
1 parent 1fdd775 commit 72fabcd
Show file tree
Hide file tree
Showing 20 changed files with 141 additions and 72 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ features = ["doc-only"]

[dependencies]
rust_tokenizers = "~6.2.4"
tch = "~0.5.0"
tch = { version = "0.5.0", path = "E:/Coding/tch-rs" }
serde_json = "1.0.66"
serde = { version = "1.0.129", features = ["derive"] }
dirs = "3.0.2"
Expand All @@ -73,5 +73,5 @@ half = "1.7.1"
anyhow = "1.0.43"
csv = "1.1.6"
criterion = "0.3.5"
torch-sys = "0.5.0"
torch-sys = { version = "0.5.0", path = "E:/Coding/tch-rs/torch-sys" }
tempfile = "3.2.0"
12 changes: 7 additions & 5 deletions examples/generation_gpt_neo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ use tch::Device;
fn main() -> anyhow::Result<()> {
// Set-up model resources
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoConfigResources::GPT_NEO_1_3B,
GptNeoConfigResources::GPT_NEO_125M,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoVocabResources::GPT_NEO_1_3B,
GptNeoVocabResources::GPT_NEO_125M,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoMergesResources::GPT_NEO_1_3B,
GptNeoMergesResources::GPT_NEO_125M,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
GptNeoModelResources::GPT_NEO_1_3B,
GptNeoModelResources::GPT_NEO_125M,
));
let generate_config = TextGenerationConfig {
model_type: ModelType::GPTNeo,
Expand All @@ -52,7 +52,9 @@ fn main() -> anyhow::Result<()> {
..Default::default()
};

let model = TextGenerationModel::new(generate_config)?;
let mut model = TextGenerationModel::new(generate_config)?;
// model.half();
model.set_device(Device::cuda_if_available());

let input_context_1 = "It was a very nice and sunny";
let input_context_2 = "It was a gloom winter night, and";
Expand Down
3 changes: 2 additions & 1 deletion src/albert/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ impl AlbertSelfAttention {
self.hidden_size,
));

let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
let context: Tensor =
Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + self.dense.bs.as_ref().unwrap();
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);

if !self.output_attentions {
Expand Down
3 changes: 3 additions & 0 deletions src/bart/bart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,9 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
20 changes: 0 additions & 20 deletions src/common/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,6 @@ use crate::RustBertError;
use half;
use tch::{Kind, Scalar};

pub(crate) fn get_positive_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
Ok(match kind {
Kind::Uint8 => Scalar::int(u8::MAX.into()),
Kind::Int8 => Scalar::int(i8::MAX.into()),
Kind::Int16 => Scalar::int(i16::MAX.into()),
Kind::Int => Scalar::int(i32::MAX.into()),
Kind::Int64 => Scalar::int(i64::MAX),
Kind::Half => Scalar::float(half::f16::MAX.into()),
Kind::Float => Scalar::float(f32::MAX.into()),
Kind::BFloat16 => Scalar::float(half::bf16::MAX.into()),
Kind::Double => Scalar::float(f64::MAX),
_ => {
return Err(RustBertError::ValueError(format!(
"Type not supported: attempted to get positive infinity for {:?}",
kind
)))
}
})
}

pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError> {
Ok(match kind {
Kind::Uint8 => Scalar::int(u8::MIN.into()),
Expand Down
4 changes: 2 additions & 2 deletions src/common/summary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::xlnet::XLNetConfig;
use crate::RustBertError;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use tch::{nn, Kind, Tensor};
use tch::{nn, Tensor};

#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
Expand Down Expand Up @@ -132,7 +132,7 @@ impl SequenceSummary {
let mut output = match self.summary_type {
SummaryType::last => hidden_states.select(1, -1),
SummaryType::first => hidden_states.select(1, 0),
SummaryType::mean => hidden_states.mean_dim(&[1], false, Kind::Float),
SummaryType::mean => hidden_states.mean_dim(&[1], false, hidden_states.kind()),
SummaryType::cls_index => {
let cls_index = if let Some(cls_index_value) = cls_index {
let mut expand_dim = vec![-1i64; cls_index_value.dim() - 1];
Expand Down
3 changes: 3 additions & 0 deletions src/gpt2/gpt2_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,9 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
31 changes: 12 additions & 19 deletions src/gpt_neo/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use crate::gpt_neo::gpt_neo_model::AttentionLayerType;
use crate::gpt_neo::GptNeoConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::Init;
use tch::{nn, Device, Kind, Tensor};

#[derive(Debug)]
Expand Down Expand Up @@ -207,23 +206,28 @@ pub(crate) trait GptNeoAttentionUtils {
key: &Tensor,
value: &Tensor,
causal_mask: &Tensor,
masked_bias: &Tensor,
attention_dropout: &Dropout,
attention_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Tensor) {
let mut attention_weights = query
.matmul(&key.transpose(-1, -2))
.where_self(causal_mask, &masked_bias.to_kind(query.kind()));
let query = query.to_kind(Kind::Float);
let key = key.to_kind(Kind::Float);

let attention_weights = query.matmul(&key.transpose(-1, -2));
let mut attention_weights = attention_weights.where_self(
causal_mask,
&Tensor::of_slice(&[-1e9f32]).to_device(attention_weights.device()),
);

if let Some(attention_mask_value) = attention_mask {
attention_weights = attention_weights + attention_mask_value;
};

attention_weights = attention_weights
.softmax(-1, Kind::Float)
let attention_weights2 = attention_weights
.softmax(-1, attention_weights.kind())
.to_kind(value.kind())
.apply_t(attention_dropout, train);
let attention_output = attention_weights.matmul(value);
let attention_output = attention_weights2.matmul(value);
(attention_output, attention_weights)
}
}
Expand All @@ -236,7 +240,6 @@ pub struct GptNeoSelfAttention {
attention_dropout: Dropout,
resid_dropout: Dropout,
bias: Tensor,
masked_bias: Tensor,
num_heads: i64,
head_dim: i64,
output_attentions: bool,
Expand All @@ -259,8 +262,6 @@ impl GptNeoSelfAttention {

let bias = p.var_copy("bias", &bias_value);

let masked_bias = p.var("masked_bias", &[1], Init::Const(-1e9));

let attention_dropout = Dropout::new(config.attention_dropout);
let resid_dropout = Dropout::new(config.resid_dropout);

Expand Down Expand Up @@ -306,7 +307,6 @@ impl GptNeoSelfAttention {
attention_dropout,
resid_dropout,
bias,
masked_bias,
num_heads,
head_dim,
output_attentions,
Expand Down Expand Up @@ -357,7 +357,6 @@ impl GptNeoSelfAttention {
&key,
&value,
&causal_mask,
&self.masked_bias,
&self.attention_dropout,
attention_mask,
train,
Expand All @@ -384,7 +383,6 @@ pub struct GptNeoLocalSelfAttention {
out_proj: nn::Linear,
attention_dropout: Dropout,
resid_dropout: Dropout,
masked_bias: Tensor,
num_heads: i64,
head_dim: i64,
window_size: i64,
Expand All @@ -401,8 +399,6 @@ impl GptNeoLocalSelfAttention {
{
let p = p.borrow();

let masked_bias = p.var("masked_bias", &[1], Init::Const(-1e9));

let attention_dropout = Dropout::new(config.attention_dropout);
let resid_dropout = Dropout::new(config.resid_dropout);

Expand Down Expand Up @@ -449,7 +445,6 @@ impl GptNeoLocalSelfAttention {
out_proj,
attention_dropout,
resid_dropout,
masked_bias,
num_heads,
head_dim,
window_size,
Expand Down Expand Up @@ -523,7 +518,6 @@ impl GptNeoLocalSelfAttention {
&key,
&value,
attention_mask,
&self.masked_bias,
&self.attention_dropout,
None,
train,
Expand All @@ -539,7 +533,6 @@ impl GptNeoLocalSelfAttention {
} else {
None
};

Ok((attention_output, attention_weights))
}
}
Expand Down
21 changes: 12 additions & 9 deletions src/gpt_neo/gpt_neo_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,14 +339,6 @@ impl GptNeoModel {

let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());

let global_attention_mask = attention_mask.map(|attention_mask_value| {
let global_attention_mask = attention_mask_value
.view([batch_size, -1])
.unsqueeze(1)
.unsqueeze(1);
(1 - global_attention_mask) * -1e4
});

let local_attention_mask = GptNeoModel::create_local_attention_mask(
batch_size,
full_sequence_length,
Expand All @@ -358,12 +350,20 @@ impl GptNeoModel {
let input_embeds = input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let position_embeds = position_ids.apply(&self.position_embeddings);

let global_attention_mask = attention_mask.map(|attention_mask_value| {
let global_attention_mask = attention_mask_value
.view([batch_size, -1])
.unsqueeze(1)
.unsqueeze(1);
let global_attention_mask = global_attention_mask.to_kind(position_embeds.kind());
(1 - global_attention_mask) * -1e4
});

let mut hidden_state = input_embeds + position_embeds;
if let Some(token_type_ids) = token_type_ids {
hidden_state = hidden_state + token_type_ids.apply(&self.word_embeddings);
};
hidden_state = hidden_state.apply_t(&self.dropout, train);

let mut output_shape = input_shape;
output_shape.push(*hidden_state.size().last().unwrap());

Expand Down Expand Up @@ -711,6 +711,9 @@ impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for G
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
3 changes: 3 additions & 0 deletions src/m2m_100/m2m_100_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,9 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
3 changes: 3 additions & 0 deletions src/marian/marian_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,9 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
3 changes: 3 additions & 0 deletions src/mbart/mbart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,9 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
3 changes: 3 additions & 0 deletions src/openai_gpt/openai_gpt_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,9 @@ impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTok
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
3 changes: 3 additions & 0 deletions src/pegasus/pegasus_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
Expand Down
Loading

0 comments on commit 72fabcd

Please sign in to comment.