-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Addition of masked LM for DistilBERT (incl. example)
- Loading branch information
1 parent
b3cb142
commit d6716a9
Showing
7 changed files
with
155 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
use std::path::PathBuf; | ||
use tch::{Device, Tensor, nn, no_grad}; | ||
use rust_bert::distilbert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig}; | ||
use rust_transformers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy}; | ||
use rust_transformers::bert_tokenizer::BertTokenizer; | ||
use rust_transformers::preprocessing::vocab::base_vocab::Vocab; | ||
|
||
extern crate failure; | ||
extern crate dirs; | ||
|
||
fn main() -> failure::Fallible<()> { | ||
|
||
// Resources paths | ||
let mut home: PathBuf = dirs::home_dir().unwrap(); | ||
home.push("rustbert"); | ||
home.push("distilbert"); | ||
let config_path = &home.as_path().join("config.json"); | ||
let vocab_path = &home.as_path().join("vocab.txt"); | ||
let weights_path = &home.as_path().join("model.ot"); | ||
|
||
// Set-up masked LM model | ||
let device = Device::Cpu; | ||
let mut vs = nn::VarStore::new(device); | ||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap()); | ||
let config = DistilBertConfig::from_file(config_path); | ||
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config); | ||
vs.load(weights_path)?; | ||
|
||
// Define input | ||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"]; | ||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0); | ||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap(); | ||
let mut tokenized_input = tokenized_input. | ||
iter(). | ||
map(|input| input.token_ids.clone()). | ||
map(|mut input| { | ||
input.extend(vec![0; max_len - input.len()]); | ||
input | ||
}). | ||
collect::<Vec<_>>(); | ||
|
||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2 | ||
tokenized_input[0][4] = 103; | ||
tokenized_input[1][6] = 103; | ||
let tokenized_input = tokenized_input. | ||
iter(). | ||
map(|input| | ||
Tensor::of_slice(&(input))). | ||
collect::<Vec<_>>(); | ||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); | ||
|
||
|
||
// Forward pass | ||
let (output, _, _) = no_grad(|| { | ||
distil_bert_model | ||
.forward_t(Some(input_tensor), None, None, false) | ||
.unwrap() | ||
}); | ||
|
||
// Print masked tokens | ||
let index_1 = output.get(0).get(4).argmax(0, false); | ||
let index_2 = output.get(1).get(6).argmax(0, false); | ||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); | ||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); | ||
|
||
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing" | ||
println!("{}", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples" | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
pub mod distilbert; | ||
|
||
pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier}; | ||
pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM}; | ||
pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from transformers import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.tokenization_distilbert import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

# Downloads the pretrained DistilBERT assets (config, vocab, weights) into
# ~/rustbert/distilbert and converts the PyTorch weights to the `.ot` format
# consumed by the Rust example, via the repository's `convert-tensor` binary.

MODEL_NAME = 'distilbert-base-uncased'

config_path = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP[MODEL_NAME]
vocab_path = PRETRAINED_VOCAB_FILES_MAP['vocab_file'][MODEL_NAME]
weights_path = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP[MODEL_NAME]

target_path = Path.home() / 'rustbert' / 'distilbert'

# Fetch each remote asset (or reuse the local transformers cache).
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(target_path, exist_ok=True)
shutil.copy(temp_config, target_path / 'config.json')
shutil.copy(temp_vocab, target_path / 'vocab.txt')
shutil.copy(temp_weights, target_path / 'model.bin')

# Re-save the state dict as a NumPy archive so the Rust converter can read it.
# map_location='cpu' makes loading robust on machines without a GPU.
weights = torch.load(temp_weights, map_location='cpu')
nps = {k: v.cpu().numpy() for k, v in weights.items()}
np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

# Cargo.toml lives two directories above this script.
toml_location = Path(__file__).resolve().parent.parent / 'Cargo.toml'

# check=True raises CalledProcessError if the conversion fails, instead of
# silently leaving a stale or missing model.ot (subprocess.call ignored the
# return code).
subprocess.run(
    ['cargo', '+nightly', 'run', '--bin=convert-tensor', f'--manifest-path={toml_location}', '--', source,
     target],
    check=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters