Add pipelines::masked_language and codebert support (guillaume-be#282)
* add support for loading a local model in SequenceClassificationConfig

* adjust config to match the SequenceClassificationConfig

* add pipelines::masked_language

* add support and example for codebert

* provide an optional mask_token String field for the masked_language pipeline

* update example for masked_language pipeline

* revert codebert support

* revert support for loading a local model

* solve conflicts

* update MaskedLanguageConfig

* fix doctest error in zero_shot_classification.rs

* MaskedLM pipeline updates

* fix multiple masked token, added test

* Updated changelog and docs

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
Vincent-Xiao and guillaume-be committed Dec 21, 2022
1 parent a34cf9f commit dae899f
Showing 11 changed files with 839 additions and 105 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. The format
## Added
- Addition of All-MiniLM-L6-V2 model weights
- Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT)
- Addition of Masked Language Model pipeline, allowing to predict masked words.

## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`).
12 changes: 9 additions & 3 deletions Cargo.toml
@@ -8,7 +8,13 @@ repository = "https://github.com/guillaume-be/rust-bert"
documentation = "https://docs.rs/rust-bert"
license = "Apache-2.0"
readme = "README.md"
keywords = ["nlp", "deep-learning", "machine-learning", "transformers", "translation"]
keywords = [
"nlp",
"deep-learning",
"machine-learning",
"transformers",
"translation",
]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

@@ -57,7 +63,7 @@ opt-level = 3
default = ["remote"]
doc-only = ["tch/doc-only"]
all-tests = []
remote = [ "cached-path", "dirs", "lazy_static" ]
remote = ["cached-path", "dirs", "lazy_static"]

[package.metadata.docs.rs]
features = ["doc-only"]
@@ -82,6 +88,6 @@ anyhow = "1.0.58"
csv = "1.1.6"
criterion = "0.3.6"
tokio = { version = "1.20.0", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "~0.9.0"
torch-sys = "0.9.0"
tempfile = "3.3.0"
itertools = "0.10.3"
28 changes: 28 additions & 0 deletions README.md
@@ -33,6 +33,7 @@ The tasks currently supported include:
- Part of Speech tagging
- Question-Answering
- Language Generation
- Masked Language Model
- Sentence Embeddings

<details>
@@ -436,6 +437,33 @@ Output:
```
</details>


<details>
<summary> <b>12. Masked Language Model </b> </summary>

Predict masked words in input sentences.
```rust
let model = MaskedLanguageModel::new(Default::default())?;

let sentences = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];

let output = model.predict(&sentences);
```
Output:
```
[
[MaskedToken { text: "college", id: 2267, score: 8.091}],
[
MaskedToken { text: "capital", id: 3007, score: 16.7249},
MaskedToken { text: "located", id: 2284, score: 9.0452}
]
]
```
</details>

## Benchmarks

For simple pipelines (sequence classification, token classification, question answering) the performance between Python and Rust is expected to be comparable. This is because the most expensive part of these pipelines is the language model itself, which shares a common implementation in the Torch backend. The [End-to-end NLP Pipelines in Rust](https://www.aclweb.org/anthology/2020.nlposs-1.4/) paper provides a benchmarks section covering all pipelines.
46 changes: 46 additions & 0 deletions examples/masked_language.rs
@@ -0,0 +1,46 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate anyhow;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::masked_language::{MaskedLanguageConfig, MaskedLanguageModel};
use rust_bert::resources::RemoteResource;
fn main() -> anyhow::Result<()> {
// Set-up model
    let config = MaskedLanguageConfig::new(
        ModelType::Bert,
        RemoteResource::from_pretrained(BertModelResources::BERT),
        RemoteResource::from_pretrained(BertConfigResources::BERT),
        RemoteResource::from_pretrained(BertVocabResources::BERT),
        None,                          // merges resource (not needed for BERT)
        true,                          // lower-case the input
        None,                          // strip accents
        None,                          // add prefix space
        Some(String::from("<mask>")),  // custom mask token
    );

let mask_language_model = MaskedLanguageModel::new(config)?;
// Define input
let input = [
"Hello I am a <mask> student",
"Paris is the <mask> of France. It is <mask> in Europe.",
];

// Run model
let output = mask_language_model.predict(input)?;
for sentence_output in output {
println!("{:?}", sentence_output);
}

Ok(())
}
96 changes: 0 additions & 96 deletions examples/masked_language_model_bert.rs

This file was deleted.

33 changes: 33 additions & 0 deletions src/lib.rs
@@ -41,6 +41,7 @@
//! - Question-Answering
//! - Language Generation
//! - Sentence Embeddings
//! - Masked Language Model
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust
@@ -613,6 +614,38 @@
//! # ;
//! ```
//! </details>
//! &nbsp;
//! <details>
//! <summary> <b>12. Masked Language Model </b> </summary>
//!
//! Predict masked words in input sentences.
//!```no_run
//! # use rust_bert::pipelines::masked_language::MaskedLanguageModel;
//! # fn main() -> anyhow::Result<()> {
//! let model = MaskedLanguageModel::new(Default::default())?;
//!
//! let sentences = [
//! "Hello I am a <mask> student",
//! "Paris is the <mask> of France. It is <mask> in Europe.",
//! ];
//!
//! let output = model.predict(&sentences);
//! # Ok(())
//! # }
//! ```
//! Output:
//!```no_run
//! # use rust_bert::pipelines::masked_language::MaskedToken;
//! let output = vec![
//! vec![MaskedToken { text: String::from("college"), id: 2267, score: 8.091}],
//! vec![
//! MaskedToken { text: String::from("capital"), id: 3007, score: 16.7249},
//! MaskedToken { text: String::from("located"), id: 2284, score: 9.0452}
//! ]
//! ]
//! # ;
//! ```
//! </details>
//!
//! ## Benchmarks
//!
108 changes: 108 additions & 0 deletions src/pipelines/common.rs
@@ -1533,6 +1533,114 @@ impl TokenizerOption {
}
}

/// Interface method
pub fn get_mask_id(&self) -> Option<i64> {
match *self {
Self::Bert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(BertVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Deberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::DebertaV2(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(DeBERTaV2Vocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Roberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Bart(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(RobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLMRobertaVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(AlbertVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::XLNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(XLNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::ProphetNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(ProphetNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::MBart50(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(MBart50Vocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::FNet(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(FNetVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Pegasus(ref tokenizer) => Some(
*MultiThreadedTokenizer::vocab(tokenizer)
.special_values
.get(PegasusVocab::mask_value())
.expect("MASK token not found in vocabulary"),
),
Self::Marian(_) => None,
Self::M2M100(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
Self::Reformer(_) => None,
}
}

/// Interface method
pub fn get_mask_value(&self) -> Option<&str> {
match self {
Self::Bert(_) => Some(BertVocab::mask_value()),
Self::Deberta(_) => Some(DeBERTaVocab::mask_value()),
Self::DebertaV2(_) => Some(DeBERTaV2Vocab::mask_value()),
Self::Roberta(_) => Some(RobertaVocab::mask_value()),
Self::Bart(_) => Some(RobertaVocab::mask_value()),
Self::XLMRoberta(_) => Some(XLMRobertaVocab::mask_value()),
Self::Albert(_) => Some(AlbertVocab::mask_value()),
Self::XLNet(_) => Some(XLNetVocab::mask_value()),
Self::ProphetNet(_) => Some(ProphetNetVocab::mask_value()),
Self::MBart50(_) => Some(MBart50Vocab::mask_value()),
            Self::FNet(_) => Some(FNetVocab::mask_value()),
Self::M2M100(_) => None,
Self::Marian(_) => None,
Self::T5(_) => None,
Self::GPT2(_) => None,
Self::OpenAiGpt(_) => None,
Self::Reformer(_) => None,
Self::Pegasus(_) => None,
}
}

/// Interface method
pub fn get_bos_id(&self) -> Option<i64> {
match *self {
