Addition of masked LM for DistilBERT (incl. example)
guillaume-be committed Feb 13, 2020
1 parent b3cb142 commit d6716a9
Showing 7 changed files with 155 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "rust-bert"
-version = "0.2.0"
+version = "0.3.0"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
default-run = "rust-bert"
@@ -21,7 +21,7 @@ name = "convert-tensor"
path = "src/convert-tensor.rs"

[dependencies]
-rust_transformers = "0.1.0"
+rust_transformers = "0.2.0"
tch = "0.1.6"
serde_json = "1.0.45"
serde = {version = "1.0.104", features = ["derive"]}
70 changes: 70 additions & 0 deletions examples/distilbert_masked_lm.rs
@@ -0,0 +1,70 @@
use std::path::PathBuf;
use tch::{Device, Tensor, nn, no_grad};
use rust_bert::distilbert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig};
use rust_transformers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_transformers::bert_tokenizer::BertTokenizer;
use rust_transformers::preprocessing::vocab::base_vocab::Vocab;

extern crate failure;
extern crate dirs;

fn main() -> failure::Fallible<()> {

    // Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("distilbert");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");
    let weights_path = &home.as_path().join("model.ot");

    // Set up the masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap());
    let config = DistilBertConfig::from_file(config_path);
    let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = ["Looks like one thing is missing", "It's like comparing oranges to apples"];
    let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
    // Pad every sequence to the length of the longest one
    let mut tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .collect::<Vec<_>>();

    // Mask the token [thing] in sentence 1 and [oranges] in sentence 2
    // (103 is the id of the [MASK] token in the bert-base-uncased vocabulary)
    tokenized_input[0][4] = 103;
    tokenized_input[1][6] = 103;
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| Tensor::of_slice(input))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) = no_grad(|| {
        distil_bert_model
            .forward_t(Some(input_tensor), None, None, false)
            .unwrap()
    });

    // Print the predicted tokens at the masked positions
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(6).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    println!("{}", word_1); // Outputs "person": "Looks like one [person] is missing"
    println!("{}", word_2); // Outputs "pear": "It's like comparing [pear] to apples"

    Ok(())
}
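The decoding block above hardcodes the two masked positions. The same argmax-over-the-vocabulary step works for any number of masks; here is a minimal standalone sketch of just that step, with a toy logits tensor and a hypothetical five-word vocabulary standing in for the model output and tokenizer.vocab():

use tch::{Device, Kind, Tensor};

fn main() {
    // Toy stand-in for the (batch, seq_len, vocab_size) logits returned by forward_t
    let (batch, seq_len, vocab_size) = (2i64, 4, 5);
    let logits = Tensor::rand(&[batch, seq_len, vocab_size], (Kind::Float, Device::Cpu));

    // Hypothetical vocabulary standing in for tokenizer.vocab()
    let vocab = ["apples", "oranges", "pears", "missing", "[MASK]"];

    // The predicted token at each masked position is the argmax over the vocab dimension
    let masked_positions: [(i64, i64); 2] = [(0, 1), (1, 3)]; // (sentence, token) indices
    for &(i, j) in masked_positions.iter() {
        let predicted_id = logits.get(i).get(j).argmax(0, false).int64_value(&[]);
        println!("sentence {}, position {} -> {}", i, j, vocab[predicted_id as usize]);
    }
}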
45 changes: 40 additions & 5 deletions src/distilbert/distilbert.rs
@@ -36,23 +36,23 @@ pub struct DistilBertConfig {
    pub dim: i64,
    pub dropout: f64,
    pub hidden_dim: i64,
-   pub id2label: HashMap<i32, String>,
+   pub id2label: Option<HashMap<i32, String>>,
    pub initializer_range: f32,
-   pub is_decoder: bool,
-   pub label2id: HashMap<String, i32>,
+   pub is_decoder: Option<bool>,
+   pub label2id: Option<HashMap<String, i32>>,
    pub max_position_embeddings: i64,
    pub n_heads: i64,
    pub n_layers: i64,
    pub num_labels: i64,
    pub output_attentions: bool,
    pub output_hidden_states: bool,
-   pub output_past: bool,
+   pub output_past: Option<bool>,
    pub qa_dropout: f32,
    pub seq_classif_dropout: f64,
    pub sinusoidal_pos_embds: bool,
    pub tie_weights_: bool,
    pub torchscript: bool,
-   pub use_bfloat16: bool,
+   pub use_bfloat16: Option<bool>,
    pub vocab_size: i64,
}

@@ -138,3 +138,38 @@ impl DistilBertModelClassifier {
        Ok((output, all_hidden_states, all_attentions))
    }
}

pub struct DistilBertModelMaskedLM {
    distil_bert_model: DistilBertModel,
    vocab_transform: nn::Linear,
    vocab_layer_norm: nn::LayerNorm,
    vocab_projector: nn::Linear,
}

impl DistilBertModelMaskedLM {
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
        let distil_bert_model = DistilBertModel::new(&p, config);
        let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
        let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
        let vocab_layer_norm = nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
        let vocab_projector = nn::linear(&(p / "vocab_projector"), config.dim, config.vocab_size, Default::default());

        DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
    }

    pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
                     -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
            Ok(value) => value,
            Err(err) => return Err(err)
        };

        // Masked LM head: linear transform, GELU, layer norm, then projection to vocabulary logits
        let output = output
            .apply_t(&self.vocab_transform, train)
            .gelu()
            .apply_t(&self.vocab_layer_norm, train)
            .apply_t(&self.vocab_projector, train);

        Ok((output, all_hidden_states, all_attentions))
    }
}
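The head applied in forward_t maps hidden states of shape (batch, seq_len, dim) to vocabulary logits of shape (batch, seq_len, vocab_size). Below is a standalone sketch of the same transform/GELU/layer-norm/projector pipeline on a random tensor, written against the tch API used elsewhere in this commit and assuming the distilbert-base-uncased sizes; the weights are untrained, so only the shapes are meaningful:

use tch::{nn, Device, Kind, Tensor};

fn main() {
    let (dim, vocab_size) = (768, 30522); // assumed distilbert-base-uncased dimensions
    let vs = nn::VarStore::new(Device::Cpu);
    let p = vs.root();
    let vocab_transform = nn::linear(&p / "vocab_transform", dim, dim, Default::default());
    let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
    let vocab_layer_norm = nn::layer_norm(&p / "vocab_layer_norm", vec![dim], layer_norm_config);
    let vocab_projector = nn::linear(&p / "vocab_projector", dim, vocab_size, Default::default());

    // Random (batch, seq_len, dim) hidden states stand in for the transformer output
    let hidden_states = Tensor::rand(&[2, 8, dim], (Kind::Float, Device::Cpu));
    let logits = hidden_states
        .apply(&vocab_transform)
        .gelu()
        .apply(&vocab_layer_norm)
        .apply(&vocab_projector);
    assert_eq!(logits.size(), vec![2, 8, vocab_size]); // (batch, seq_len, vocab_size)
}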
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -1,4 +1,4 @@
pub mod distilbert;

-pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier};
+pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM};
pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
1 change: 1 addition & 0 deletions src/main.rs
@@ -10,6 +10,7 @@ fn main() -> failure::Fallible<()> {
    // Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
+   home.push("distilbert_sst2");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");
    let weights_path = &home.as_path().join("model.ot");
40 changes: 40 additions & 0 deletions utils/download-dependencies_distilbert.py
@@ -0,0 +1,40 @@
from transformers import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.tokenization_distilbert import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

config_path = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP["distilbert-base-uncased"]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["distilbert-base-uncased"]
weights_path = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP["distilbert-base-uncased"]

target_path = Path.home() / 'rustbert' / 'distilbert'

temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(target_path, exist_ok=True)
shutil.copy(temp_config, target_path / 'config.json')
shutil.copy(temp_vocab, target_path / 'vocab.txt')
shutil.copy(temp_weights, target_path / 'model.bin')

# Re-export the weights as a NumPy archive for conversion to the tch format
weights = torch.load(temp_weights)
nps = {}
for k, v in weights.items():
    nps[k] = v.cpu().numpy()

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

subprocess.call(
    ['cargo', '+nightly', 'run', '--bin=convert-tensor', f'--manifest-path={toml_location}', '--', source,
     target])
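The final step shells out to the crate's convert-tensor binary, whose source is not part of this diff. A plausible reading is that it loads the .npz archive and re-saves it in the libtorch-compatible .ot format, which tch can do in a few lines; the following is a hedged sketch of such a converter, not necessarily the actual src/convert-tensor.rs:

extern crate failure;

use tch::Tensor;

fn main() -> failure::Fallible<()> {
    // Usage: convert-tensor <source.npz> <target.ot>
    let args: Vec<String> = std::env::args().collect();
    let named_tensors = Tensor::read_npz(&args[1])?; // Vec<(String, Tensor)> of named weights
    Tensor::save_multi(&named_tensors, &args[2])?;   // write the tch-readable .ot file
    Ok(())
}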
@@ -12,7 +12,7 @@
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["distilbert-base-uncased"]
weights_path = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP["distilbert-base-uncased-finetuned-sst-2-english"]

-target_path = Path.home() / 'rustbert'
+target_path = Path.home() / 'rustbert' / 'distilbert_sst2'

temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
