Skip to content

Commit

Permalink
Merge pull request #62 from guillaume-be/generalized_qa_patch
Browse files Browse the repository at this point in the history
Generalized qa patch
  • Loading branch information
guillaume-be committed Jul 12, 2020
2 parents d758595 + c9262a3 commit fbd04ad
Show file tree
Hide file tree
Showing 11 changed files with 335 additions and 57 deletions.
38 changes: 38 additions & 0 deletions examples/download_all_dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,42 @@ fn download_t5_small() -> failure::Fallible<()> {
Ok(())
}

/// Downloads the configuration, vocabulary, merges and weights resources
/// required by the RoBERTa extractive question answering model, returning
/// any download failure to the caller.
fn download_roberta_qa() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
    // Resources are listed in the order they are fetched: config, vocab, merges, weights.
    let resources = [
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaConfigResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaVocabResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaMergesResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaModelResources::ROBERTA_QA,
        )),
    ];
    for resource in resources.iter() {
        // Return value (local path) is intentionally discarded; only the
        // side effect of caching the file locally matters here.
        let _ = download_resource(resource)?;
    }
    Ok(())
}

/// Downloads the configuration, vocabulary and weights resources required by
/// the BERT extractive question answering model, returning any download
/// failure to the caller.
fn download_bert_qa() -> failure::Fallible<()> {
    // Fixed attribution: the previous comment was copy-pasted from the RoBERTa
    // helper and wrongly credited deepset/roberta-base-squad2. The BERT_QA
    // resources point at Hugging Face's bert-large-cased-whole-word-masking-finetuned-squad
    // (see the BERT_QA constant in src/bert/bert.rs).
    // Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        BertConfigResources::BERT_QA,
    ));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA));
    // No merges resource: BERT uses a WordPiece vocabulary, unlike the
    // RoBERTa pipeline which also needs a BPE merges file.
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
    Ok(())
}

fn main() -> failure::Fallible<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
Expand All @@ -328,6 +364,8 @@ fn main() -> failure::Fallible<()> {
let _ = download_electra_discriminator();
let _ = download_albert_base_v2();
let _ = download_t5_small();
let _ = download_roberta_qa();
let _ = download_bert_qa();

Ok(())
}
6 changes: 2 additions & 4 deletions examples/question_answering_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@ fn main() -> failure::Fallible<()> {
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
BertConfigResources::BERT_QA,
)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
);
Expand Down
2 changes: 1 addition & 1 deletion src/bert/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl BertModelResources {
/// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/model.ot",
"https://cdn.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad/rust_model.ot",
"https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-rust_model.ot",
);
}

Expand Down
49 changes: 48 additions & 1 deletion src/pipelines/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//! pre-processing, forward pass and postprocessing differs between pipelines while basic config and
//! tokenization objects don't.
//!
use crate::albert::AlbertConfig;
use crate::bart::BartConfig;
use crate::bert::BertConfig;
use crate::distilbert::DistilBertConfig;
Expand All @@ -28,10 +29,12 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
};
use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
use rust_tokenizers::preprocessing::tokenizer::t5_tokenizer::T5Tokenizer;
use rust_tokenizers::preprocessing::vocab::albert_vocab::AlbertVocab;
use rust_tokenizers::preprocessing::vocab::marian_vocab::MarianVocab;
use rust_tokenizers::preprocessing::vocab::t5_vocab::T5Vocab;
use rust_tokenizers::{
BertTokenizer, BertVocab, RobertaTokenizer, RobertaVocab, TokenizedInput, TruncationStrategy,
AlbertTokenizer, BertTokenizer, BertVocab, RobertaTokenizer, RobertaVocab, TokenizedInput,
TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand All @@ -46,6 +49,7 @@ pub enum ModelType {
Electra,
Marian,
T5,
Albert,
}

/// # Abstraction that holds a model configuration, can be of any of the supported models
Expand All @@ -60,6 +64,8 @@ pub enum ConfigOption {
Marian(BartConfig),
/// T5 configuration
T5(T5Config),
/// Albert configuration
Albert(AlbertConfig),
}

/// # Abstraction that holds a particular tokenizer, can be of any of the supported models
Expand All @@ -72,6 +78,8 @@ pub enum TokenizerOption {
Marian(MarianTokenizer),
/// T5 Tokenizer
T5(T5Tokenizer),
/// Albert Tokenizer
Albert(AlbertTokenizer),
}

impl ConfigOption {
Expand All @@ -83,6 +91,7 @@ impl ConfigOption {
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
ModelType::Marian => ConfigOption::Marian(BartConfig::from_file(path)),
ModelType::T5 => ConfigOption::T5(T5Config::from_file(path)),
ModelType::Albert => ConfigOption::Albert(AlbertConfig::from_file(path)),
}
}

Expand All @@ -100,6 +109,9 @@ impl ConfigOption {
Self::Marian(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Albert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
}
}
Expand Down Expand Up @@ -128,6 +140,11 @@ impl TokenizerOption {
lower_case,
)),
ModelType::T5 => TokenizerOption::T5(T5Tokenizer::from_file(vocab_path, lower_case)),
ModelType::Albert => TokenizerOption::Albert(AlbertTokenizer::from_file(
vocab_path,
lower_case,
!lower_case,
)),
}
}

Expand All @@ -138,6 +155,7 @@ impl TokenizerOption {
Self::Roberta(_) => ModelType::Roberta,
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
Self::Albert(_) => ModelType::Albert,
}
}

Expand All @@ -162,6 +180,9 @@ impl TokenizerOption {
Self::T5(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::Albert(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
}
}

Expand All @@ -172,6 +193,7 @@ impl TokenizerOption {
Self::Roberta(ref tokenizer) => tokenizer.tokenize(text),
Self::Marian(ref tokenizer) => tokenizer.tokenize(text),
Self::T5(ref tokenizer) => tokenizer.tokenize(text),
Self::Albert(ref tokenizer) => tokenizer.tokenize(text),
}
}

Expand Down Expand Up @@ -235,6 +257,16 @@ impl TokenizerOption {
mask_1,
mask_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
}
}

Expand All @@ -245,6 +277,7 @@ impl TokenizerOption {
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}

Expand Down Expand Up @@ -279,6 +312,13 @@ impl TokenizerOption {
.get(T5Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*tokenizer
.vocab()
.special_values
.get(T5Vocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
}
}

Expand All @@ -299,6 +339,13 @@ impl TokenizerOption {
.get(RobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*tokenizer
.vocab()
.special_values
.get(AlbertVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Marian(_) => None,
Self::T5(_) => None,
}
Expand Down
15 changes: 15 additions & 0 deletions src/pipelines/question_answering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
//! # ;
//! ```

use crate::albert::AlbertForQuestionAnswering;
use crate::bert::BertForQuestionAnswering;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
Expand Down Expand Up @@ -250,6 +251,8 @@ pub enum QuestionAnsweringOption {
DistilBert(DistilBertForQuestionAnswering),
/// Roberta for Question Answering
Roberta(RobertaForQuestionAnswering),
/// Albert for Question Answering
Albert(AlbertForQuestionAnswering),
}

impl QuestionAnsweringOption {
Expand Down Expand Up @@ -289,6 +292,13 @@ impl QuestionAnsweringOption {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Albert => {
if let ConfigOption::Albert(config) = config {
QuestionAnsweringOption::Albert(AlbertForQuestionAnswering::new(p, config))
} else {
panic!("You can only supply an AlbertConfig for Albert!");
}
}
ModelType::Electra => {
panic!("QuestionAnswering not implemented for Electra!");
}
Expand All @@ -307,6 +317,7 @@ impl QuestionAnsweringOption {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert,
Self::Albert(_) => ModelType::Albert,
}
}

Expand All @@ -333,6 +344,10 @@ impl QuestionAnsweringOption {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
(outputs.0, outputs.1)
}
Self::Albert(ref model) => {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
(outputs.0, outputs.1)
}
}
}
}
Expand Down
Loading

0 comments on commit fbd04ad

Please sign in to comment.