Generalization of input types for pipelines
guillaume-be committed Oct 11, 2020
1 parent 426430a commit 97ee8ee
Showing 31 changed files with 248 additions and 284 deletions.
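The change in a nutshell: tokenizer-facing methods that previously took an owned Vec<&str> (forcing callers to clone with .to_vec()) now borrow a slice (&[&str]). A minimal before/after sketch, assuming a BERT-style tokenizer built as in the repository's examples (the vocab path is a placeholder):

use rust_tokenizers::tokenizer::{BertTokenizer, Tokenizer, TruncationStrategy};

fn main() -> anyhow::Result<()> {
    // Placeholder path: any BERT-style vocab file works here.
    let tokenizer = BertTokenizer::from_file("path/to/vocab.txt", true, true)?;
    let input = [
        "Looks like one [MASK] is missing",
        "It was a very nice and [MASK] day",
    ];

    // Before this commit: the inputs had to be cloned into an owned Vec.
    // let tokenized_input =
    //     tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);

    // After this commit: a borrowed slice suffices, with no extra allocation.
    let tokenized_input =
        tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
    println!("{:?}", tokenized_input[0].token_ids);
    Ok(())
}

Borrowing lets arrays, Vecs, and plain slices share the same entry point, so the example and pipeline call sites below simply replace input.to_vec() with &input.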
2 changes: 1 addition & 1 deletion benches/squad_benchmark.rs
@@ -78,7 +78,7 @@ fn bench_squad(c: &mut Criterion) {
     }
     // Define input
     let mut squad_path = PathBuf::from(env::var("squad_dataset")
-        .expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
+        .expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
     squad_path.push("dev-v2.0.json");
     let mut qa_inputs = squad_processor(squad_path);
     qa_inputs.truncate(1000);
3 changes: 1 addition & 2 deletions examples/albert.rs
@@ -52,8 +52,7 @@ fn main() -> anyhow::Result<()> {
         "Looks like one [MASK] is missing",
         "It was a very nice and [MASK] day",
     ];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
3 changes: 1 addition & 2 deletions examples/bart.rs
@@ -54,8 +54,7 @@ fn main() -> anyhow::Result<()> {
 
     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
 
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 1024, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
3 changes: 1 addition & 2 deletions examples/bert.rs
@@ -47,8 +47,7 @@ fn main() -> anyhow::Result<()> {
         "Looks like one [MASK] is missing",
         "It was a very nice and [MASK] day",
     ];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
3 changes: 1 addition & 2 deletions examples/distilbert_masked_lm.rs
@@ -50,8 +50,7 @@ fn main() -> anyhow::Result<()> {
         "Looks like one thing is missing",
         "It\'s like comparing oranges to apples",
     ];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
2 changes: 1 addition & 1 deletion examples/electra_discriminator.rs
@@ -51,7 +51,7 @@ fn main() -> anyhow::Result<()> {
     let input = ["One Two Three Ten Five Six Seven Eight"];
     let tokenized_input = MultiThreadedTokenizer::encode_list(
         &tokenizer,
-        input.to_vec(),
+        &input,
         128,
         &TruncationStrategy::LongestFirst,
         0,
3 changes: 1 addition & 2 deletions examples/electra_masked_lm.rs
@@ -51,8 +51,7 @@ fn main() -> anyhow::Result<()> {
         "Looks like one [MASK] is missing",
         "It was a very nice and [MASK] day",
     ];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
3 changes: 1 addition & 2 deletions examples/gpt2.rs
@@ -51,8 +51,7 @@ fn main() -> anyhow::Result<()> {
 
     // Define input
     let input = ["One two three four five six seven eight nine ten eleven"];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
3 changes: 1 addition & 2 deletions examples/openai_gpt.rs
@@ -56,8 +56,7 @@ fn main() -> anyhow::Result<()> {
 
     // Define input
     let input = ["Wondering what the next word will"];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
3 changes: 1 addition & 2 deletions examples/roberta.rs
@@ -60,8 +60,7 @@ fn main() -> anyhow::Result<()> {
         "<pad> Looks like one thing is missing",
         "It\'s like comparing oranges to apples",
     ];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
3 changes: 1 addition & 2 deletions examples/xlnet.rs
@@ -49,8 +49,7 @@ fn main() -> anyhow::Result<()> {
 
     // Define input
     let input = ["One two three four"];
-    let tokenized_input =
-        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
     let max_len = tokenized_input
         .iter()
         .map(|input| input.token_ids.len())
145 changes: 48 additions & 97 deletions src/pipelines/common.rs
@@ -32,7 +32,7 @@ use rust_tokenizers::tokenizer::{
 use rust_tokenizers::vocab::{
     AlbertVocab, BertVocab, MarianVocab, RobertaVocab, T5Vocab, XLMRobertaVocab, XLNetVocab,
 };
-use rust_tokenizers::{Mask, Offset, OffsetSize, TokenizedInput};
+use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::path::Path;
@@ -273,7 +273,7 @@ impl TokenizerOption {
     /// Interface method
     pub fn encode_list(
         &self,
-        text_list: Vec<&str>,
+        text_list: &[&str],
         max_len: usize,
         truncation_strategy: &TruncationStrategy,
         stride: usize,
@@ -330,7 +330,7 @@ impl TokenizerOption {
     /// Interface method for pair encoding
     pub fn encode_pair_list(
         &self,
-        text_pair_list: Vec<(&str, &str)>,
+        text_pair_list: &[(&str, &str)],
         max_len: usize,
         truncation_strategy: &TruncationStrategy,
         stride: usize,
@@ -400,110 +400,61 @@ impl TokenizerOption {
     /// Interface method to build input with special tokens
     pub fn build_input_with_special_tokens(
         &self,
-        tokens_1: Vec<i64>,
-        tokens_2: Option<Vec<i64>>,
-        offsets_1: Vec<Option<Offset>>,
-        offsets_2: Option<Vec<Option<Offset>>>,
-        original_offsets_1: Vec<Vec<OffsetSize>>,
-        original_offsets_2: Option<Vec<Vec<OffsetSize>>>,
-        mask_1: Vec<Mask>,
-        mask_2: Option<Vec<Mask>>,
+        token_ids_with_offsets_1: TokenIdsWithOffsets,
+        token_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
     ) -> TokenizedInput {
-        let (token_ids, segment_ids, special_tokens_mask, token_offsets, reference_offsets, mask) =
-            match *self {
-                Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
-                    tokens_1,
-                    tokens_2,
-                    offsets_1,
-                    offsets_2,
-                    original_offsets_1,
-                    original_offsets_2,
-                    mask_1,
-                    mask_2,
-                ),
-                Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
-                    tokens_1,
-                    tokens_2,
-                    offsets_1,
-                    offsets_2,
-                    original_offsets_1,
-                    original_offsets_2,
-                    mask_1,
-                    mask_2,
-                ),
-                Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
-                    tokens_1,
-                    tokens_2,
-                    offsets_1,
-                    offsets_2,
-                    original_offsets_1,
-                    original_offsets_2,
-                    mask_1,
-                    mask_2,
-                ),
-                Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
-                    tokens_1,
-                    tokens_2,
-                    offsets_1,
-                    offsets_2,
-                    original_offsets_1,
-                    original_offsets_2,
-                    mask_1,
-                    mask_2,
-                ),
-                Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
-                    tokens_1,
-                    tokens_2,
-                    offsets_1,
-                    offsets_2,
-                    original_offsets_1,
-                    original_offsets_2,
-                    mask_1,
-                    mask_2,
-                ),
-                Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
-                    tokens_1,
-                    tokens_2,
-                    offsets_1,
-                    offsets_2,
-                    original_offsets_1,
-                    original_offsets_2,
-                    mask_1,
-                    mask_2,
-                ),
-                Self::XLNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
-                    tokens_1,
-                    tokens_2,
-                    offsets_1,
-                    offsets_2,
-                    original_offsets_1,
-                    original_offsets_2,
-                    mask_1,
-                    mask_2,
-                ),
-            };
+        let token_ids_with_special_tokens = match *self {
+            Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
+                token_ids_with_offsets_1,
+                token_ids_with_offsets_2,
+            ),
+            Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
+                token_ids_with_offsets_1,
+                token_ids_with_offsets_2,
+            ),
+            Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
+                token_ids_with_offsets_1,
+                token_ids_with_offsets_2,
+            ),
+            Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
+                token_ids_with_offsets_1,
+                token_ids_with_offsets_2,
+            ),
+            Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
+                token_ids_with_offsets_1,
+                token_ids_with_offsets_2,
+            ),
+            Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
+                token_ids_with_offsets_1,
+                token_ids_with_offsets_2,
+            ),
+            Self::XLNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
+                token_ids_with_offsets_1,
+                token_ids_with_offsets_2,
+            ),
+        };
         TokenizedInput {
-            token_ids,
-            segment_ids,
-            special_tokens_mask,
+            token_ids: token_ids_with_special_tokens.token_ids,
+            segment_ids: token_ids_with_special_tokens.segment_ids,
+            special_tokens_mask: token_ids_with_special_tokens.special_tokens_mask,
             overflowing_tokens: vec![],
             num_truncated_tokens: 0,
-            token_offsets,
-            reference_offsets,
-            mask,
+            token_offsets: token_ids_with_special_tokens.token_offsets,
+            reference_offsets: token_ids_with_special_tokens.reference_offsets,
+            mask: token_ids_with_special_tokens.mask,
         }
     }

     /// Interface method to convert tokens to ids
     pub fn convert_tokens_to_ids(&self, tokens: &[String]) -> Vec<i64> {
         match *self {
-            Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
-            Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
-            Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
-            Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
-            Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
-            Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
-            Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
+            Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
+            Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
+            Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
+            Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
+            Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
+            Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
+            Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
         }
     }

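The eight parallel tokens/offsets/original_offsets/mask parameters are folded into rust_tokenizers' TokenIdsWithOffsets, one value per sequence. A sketch of assembling one for a bare sequence, assuming the struct's fields mirror the parameters they replace (ids, offsets, reference_offsets, masks):

use rust_tokenizers::{Mask, TokenIdsWithOffsets};

// Bundle one sequence's token ids with empty offset/mask metadata.
// Field names are assumed from the eight parameters this struct replaces.
fn bare_sequence(ids: Vec<i64>) -> TokenIdsWithOffsets {
    let len = ids.len();
    TokenIdsWithOffsets {
        ids,
        offsets: vec![None; len],             // no character-level offsets tracked
        reference_offsets: vec![vec![]; len], // no links back to the source text
        masks: vec![Mask::None; len],         // every token is plain content
    }
}

A pipeline can then call build_input_with_special_tokens(bare_sequence(ids), None) and read token_ids, segment_ids, and the offset fields off the returned TokenizedInput, as the match arms above show.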
4 changes: 2 additions & 2 deletions src/pipelines/conversation.rs
@@ -671,7 +671,7 @@ impl ConversationModel {
             .map(|c| &c.history)
             .collect_vec();
 
         let prompt_ids = self.encode_prompts(texts.as_ref());
         let input_tensor = self.concat_input_history(prompt_ids, history);
         let input_length = *input_tensor.size().last().unwrap() as usize;
         let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
Expand Down Expand Up @@ -791,7 +791,7 @@ impl ConversationModel {

fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
// Encode the user prompt into token ids
let tokens = self.model.get_tokenizer().tokenize_list(texts.to_vec());
let tokens = self.model.get_tokenizer().tokenize_list(texts);

tokens
.into_iter()
Expand Down
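Call sites need no rebuilding on the borrowed API: a Vec<&str> assembled at runtime coerces to the &[&str] that tokenize_list and encode_prompts now accept. A tiny illustration (strings are placeholders):

fn main() {
    // A Vec<&str> borrows as &[&str] without cloning, which is all
    // the generalized methods require.
    let texts: Vec<&str> = vec!["Hello there", "How are you today?"];
    let slice: &[&str] = texts.as_ref(); // equivalently texts.as_slice() or &texts
    assert_eq!(slice.len(), 2);
}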