Commit f17b0d7: Lib clean-up and doc landing page
guillaume-be committed Mar 21, 2020
1 parent 24335a5 commit f17b0d7
Showing 24 changed files with 131 additions and 63 deletions.
11 changes: 9 additions & 2 deletions Cargo.toml
@@ -1,9 +1,9 @@
 [package]
 name = "rust-bert"
-version = "0.5.2"
+version = "0.5.3"
 authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
 edition = "2018"
-description = "Native (Distil)BERT implementation for Rust"
+description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
 repository = "https://github.com/guillaume-be/rust-bert"
 license = "Apache-2.0"
 readme = "README.md"
@@ -19,6 +19,13 @@ crate-type = ["lib"]
 [[bin]]
 name = "convert-tensor"
 path = "src/convert-tensor.rs"
+doc = false
+
+[features]
+doc-only = ["tch/doc-only"]
+
+[package.metadata.docs.rs]
+features = [ "doc-only" ]
 
 [dependencies]
 rust_tokenizers = "2.0.3"
4 changes: 2 additions & 2 deletions README.md
@@ -34,7 +34,7 @@ Extractive question answering from a given question and context. DistilBERT mode
 let question = String::from("Where does Amy live ?");
 let context = String::from("Amy lives in Amsterdam");
 
-let answers = qa_model.predict(vec!(QaInput { question, context }), 1, 32);
+let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
 ```
 
 Output:
@@ -56,7 +56,7 @@ This may impact the results and it is recommended to submit prompts of similar l
 let input_context_1 = "The dog";
 let input_context_2 = "The cat was";
 
-let output = model.generate(Some(vec!(input_context_1, input_context_2)), 0, 30, true, false,
+let output = model.generate(Some(vec!(input_context_1, input_context_2)), 0, 30, true, false,
 5, 1.2, 0, 0.9, 1.0, 1.0, 3, 3, None);
 ```
 Example output:
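Note on the `generate` call above: it takes a long list of positional arguments that the README does not name. The sketch below annotates each position; the parameter names in the comments are assumptions drawn from the Hugging Face Transformers `generate` API that this crate ports, not names confirmed by this diff.

```rust
// Annotated re-rendering of the README call above; `model`,
// `input_context_1` and `input_context_2` come from that snippet.
// Only the positions and values come from the diff; every name in
// the comments is an assumption.
let output = model.generate(
    Some(vec!(input_context_1, input_context_2)), // prompts to condition generation on
    0,     // assumed min_length: minimum number of tokens to generate
    30,    // assumed max_length: maximum number of tokens to generate
    true,  // assumed do_sample: sample tokens instead of greedy decoding
    false, // assumed early_stopping for beam search
    5,     // assumed num_beams: beam search width
    1.2,   // assumed temperature for sampling
    0,     // assumed top_k: 0 disables top-k filtering
    0.9,   // assumed top_p: nucleus sampling probability mass
    1.0,   // assumed repetition_penalty
    1.0,   // assumed length_penalty
    3,     // assumed no_repeat_ngram_size
    3,     // assumed num_return_sequences per prompt
    None,  // assumed attention_mask override
);
```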
3 changes: 2 additions & 1 deletion examples/bert.rs
@@ -17,7 +17,8 @@ use std::path::PathBuf;
 use tch::{Device, nn, Tensor, no_grad};
 use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
 use failure::err_msg;
-use rust_bert::{BertConfig, BertForMaskedLM, Config};
+use rust_bert::bert::bert::{BertConfig, BertForMaskedLM};
+use rust_bert::Config;
 
 
 fn main() -> failure::Fallible<()> {
4 changes: 3 additions & 1 deletion examples/distilbert_masked_lm.rs
@@ -18,7 +18,9 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, Trunc
 use rust_tokenizers::bert_tokenizer::BertTokenizer;
 use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
 use failure::err_msg;
-use rust_bert::{Config, DistilBertConfig, DistilBertModelMaskedLM};
+use rust_bert::distilbert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
+use rust_bert::Config;
 
 
 fn main() -> failure::Fallible<()> {
+
2 changes: 1 addition & 1 deletion examples/generation.rs
@@ -16,7 +16,7 @@ extern crate dirs;
 use std::path::PathBuf;
 use tch::Device;
 use failure::err_msg;
-use rust_bert::{GPT2Generator, LanguageGenerator};
+use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator};
 
 
 fn main() -> failure::Fallible<()> {
3 changes: 2 additions & 1 deletion examples/gpt2.rs
@@ -17,7 +17,8 @@ use std::path::PathBuf;
 use tch::{Device, nn, Tensor};
 use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
 use failure::err_msg;
-use rust_bert::{Gpt2Config, Config, GPT2LMHeadModel, LMHeadModel};
+use rust_bert::gpt2::gpt2::{Gpt2Config, GPT2LMHeadModel, LMHeadModel};
+use rust_bert::Config;
 
 
 fn main() -> failure::Fallible<()> {
4 changes: 2 additions & 2 deletions examples/ner.rs
@@ -15,8 +15,8 @@ extern crate dirs;
 
 use std::path::PathBuf;
 use tch::Device;
-use rust_bert::NERModel;
 use failure::err_msg;
+use rust_bert::pipelines::ner::NERModel;
 
 
 fn main() -> failure::Fallible<()> {
@@ -48,7 +48,7 @@ fn main() -> failure::Fallible<()> {
     ];
 
     // Run model
-    let output = ner_model.predict(input.to_vec());
+    let output = ner_model.predict(&input);
     for entity in output {
         println!("{:?}", entity);
     }
4 changes: 3 additions & 1 deletion examples/openai_gpt.rs
@@ -16,8 +16,10 @@ extern crate dirs;
 use std::path::PathBuf;
 use tch::{Device, nn, Tensor};
 use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
-use rust_bert::{Gpt2Config, Config, OpenAIGPTLMHeadModel, LMHeadModel};
 use failure::err_msg;
+use rust_bert::gpt2::gpt2::{Gpt2Config, LMHeadModel};
+use rust_bert::openai_gpt::openai_gpt::OpenAIGPTLMHeadModel;
+use rust_bert::Config;
 
 
 fn main() -> failure::Fallible<()> {
2 changes: 1 addition & 1 deletion examples/question_answering.rs
@@ -16,7 +16,7 @@ extern crate dirs;
 use std::path::PathBuf;
 use tch::Device;
 use failure::err_msg;
-use rust_bert::{QuestionAnsweringModel, QaInput};
+use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
 
 
 fn main() -> failure::Fallible<()> {
4 changes: 3 additions & 1 deletion examples/roberta.rs
@@ -17,7 +17,9 @@ use std::path::PathBuf;
 use tch::{Device, nn, Tensor, no_grad};
 use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
 use failure::err_msg;
-use rust_bert::{BertConfig, RobertaForMaskedLM, Config};
+use rust_bert::bert::bert::BertConfig;
+use rust_bert::roberta::roberta::RobertaForMaskedLM;
+use rust_bert::Config;
 
 
 fn main() -> failure::Fallible<()> {
4 changes: 2 additions & 2 deletions examples/sentiment.rs
@@ -15,8 +15,8 @@ extern crate dirs;
 
 use std::path::PathBuf;
 use tch::Device;
-use rust_bert::SentimentClassifier;
 use failure::err_msg;
+use rust_bert::pipelines::sentiment::SentimentClassifier;
 
 
 fn main() -> failure::Fallible<()> {
@@ -49,7 +49,7 @@ fn main() -> failure::Fallible<()> {
     ];
 
     // Run model
-    let output = sentiment_classifier.predict(input.to_vec());
+    let output = sentiment_classifier.predict(&input);
     for sentiment in output {
         println!("{:?}", sentiment);
     }
2 changes: 1 addition & 1 deletion examples/squad.rs
@@ -17,7 +17,7 @@ use std::path::PathBuf;
 use tch::Device;
 use std::env;
 use failure::err_msg;
-use rust_bert::{QuestionAnsweringModel, squad_processor};
+use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
 
 
 fn main() -> failure::Fallible<()> {
4 changes: 2 additions & 2 deletions src/bert/mod.rs
@@ -1,4 +1,4 @@
 pub mod bert;
 pub mod embeddings;
-pub mod attention;
-pub mod encoder;
+mod attention;
+mod encoder;
88 changes: 68 additions & 20 deletions src/lib.rs
@@ -1,23 +1,71 @@
-mod distilbert;
-mod bert;
-mod roberta;
-mod openai_gpt;
-mod gpt2;
-mod common;
-mod pipelines;
 
-pub use common::config::Config;
-pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertForTokenClassification, DistilBertForQuestionAnswering};
+//! Ready-to-use NLP pipelines and Transformer-based models
+//!
+//! Rust-native Transformer-based models, ported from the [Transformers](https://github.com/huggingface/transformers) library using the tch-rs crate and pre-processing from rust-tokenizers.
+//! Supports multithreaded tokenization and GPU inference. This repository exposes the model base architectures, task-specific heads (see below) and ready-to-use pipelines.
+//!
+//! # Quick Start
+//!
+//! This crate can be used in two different ways:
+//! - Ready-to-use NLP pipelines for Sentiment Analysis, Named Entity Recognition, Question-Answering or Language Generation. More information on these can be found in the `pipelines` module.
+//! ```no_run
+//! use tch::Device;
+//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
+//!# use std::path::PathBuf;
+//!
+//!# fn main() -> failure::Fallible<()> {
+//!# let mut home: PathBuf = dirs::home_dir().unwrap();
+//!# home.push("rustbert");
+//!# home.push("distilbert-qa");
+//!# let config_path = &home.as_path().join("config.json");
+//!# let vocab_path = &home.as_path().join("vocab.txt");
+//!# let weights_path = &home.as_path().join("model.ot");
+//!
+//! let device = Device::cuda_if_available();
+//! let qa_model = QuestionAnsweringModel::new(vocab_path,
+//!                                            config_path,
+//!                                            weights_path, device)?;
+//!
+//! let question = String::from("Where does Amy live ?");
+//! let context = String::from("Amy lives in Amsterdam");
+//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
+//! # Ok(())
+//! # }
+//! ```
+//! - Transformer model base architectures with customized heads. These allow loading pre-trained models for customized inference in Rust.
+//!
+//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**
+//! :-----:|:-----:|:-----:|:-----:|:-----:|:-----:
+//! Masked LM|✅ |✅ |✅ | | |
+//! Sequence classification|✅ |✅ |✅| | |
+//! Token classification|✅ |✅ |✅| | |
+//! Question answering|✅ |✅ |✅| | |
+//! Multiple choices| |✅ |✅| | |
+//! Next token prediction| | | |✅|✅|
+//! Natural Language Generation| | | |✅|✅|
+//!
+//! # Loading pre-trained models
+//!
+//! The architectures defined in this crate are compatible with models trained in the [Transformers](https://github.com/huggingface/transformers) library.
+//! The model configuration and vocabulary are downloaded directly from Huggingface's repository.
+//! The model weights need to be converted to a binary format that can be read by Libtorch (the original .bin files are pickles and cannot be used directly).
+//! A Python script for downloading the required files & running the necessary steps is provided for all model classes in this library.
+//! Further models can be loaded by extending the Python scripts to point to the desired model.
+//!
+//!
+//! 1. Compile the package: cargo build --release
+//! 2. Download the model files & perform necessary conversions
+//!    - Set up a virtual environment and install dependencies
+//!    - Run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
+//! 3. Run the example cargo run --release
+//!
 
-pub use bert::bert::BertConfig;
-pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM, BertForQuestionAnswering, BertForTokenClassification, BertForMultipleChoice};
-
-pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
-
-pub use gpt2::gpt2::{Gpt2Config, Gpt2Model, GPT2LMHeadModel, LMHeadModel};
-pub use openai_gpt::openai_gpt::{OpenAiGptModel, OpenAIGPTLMHeadModel};
+pub mod distilbert;
+pub mod bert;
+pub mod roberta;
+pub mod openai_gpt;
+pub mod gpt2;
+mod common;
+pub mod pipelines;
 
-pub use pipelines::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
-pub use pipelines::ner::{Entity, NERModel};
-pub use pipelines::question_answering::{QaInput, QuestionAnsweringModel, squad_processor};
-pub use pipelines::generation::{OpenAIGenerator, GPT2Generator, LanguageGenerator};
+pub use common::config::Config;
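After this clean-up, model architectures are imported from their public modules (`rust_bert::bert::bert`, `rust_bert::gpt2::gpt2`, ...), pipelines from `rust_bert::pipelines::*`, and only `Config` remains re-exported at the crate root. A minimal sketch of the new import style, assuming `SentimentClassifier::new` mirrors the `QuestionAnsweringModel::new` call shown in the doc example above and using a hypothetical download directory:

```rust
use std::path::PathBuf;
use tch::Device;
use rust_bert::pipelines::sentiment::SentimentClassifier;

fn main() -> failure::Fallible<()> {
    // Placeholder paths; the download scripts put files under ~/rustbert/.
    // "distilbert-sst2" is a hypothetical directory name.
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("distilbert-sst2");
    let vocab_path = home.join("vocab.txt");
    let config_path = home.join("config.json");
    let weights_path = home.join("model.ot");

    // Assumed constructor signature, mirroring QuestionAnsweringModel::new.
    let classifier = SentimentClassifier::new(&vocab_path, &config_path,
                                              &weights_path, Device::cuda_if_available())?;

    // `predict` now borrows a slice instead of consuming a Vec.
    let input = ["This was a great movie."];
    for sentiment in classifier.predict(&input) {
        println!("{:?}", sentiment);
    }
    Ok(())
}
```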
3 changes: 1 addition & 2 deletions src/pipelines/generation.rs
@@ -12,12 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use crate::gpt2::gpt2::LMHeadModel;
+use crate::gpt2::gpt2::{LMHeadModel, Gpt2Config, GPT2LMHeadModel};
 use tch::{Tensor, Device, nn, no_grad};
 use rust_tokenizers::{Tokenizer, OpenAiGptTokenizer, OpenAiGptVocab, Vocab, TruncationStrategy, Gpt2Tokenizer, Gpt2Vocab};
 use crate::openai_gpt::openai_gpt::OpenAIGPTLMHeadModel;
 use std::path::Path;
-use crate::{Gpt2Config, GPT2LMHeadModel};
 use crate::common::config::Config;
 use rust_tokenizers::tokenization_utils::truncate_sequences;
 use tch::kind::Kind::{Int64, Float, Bool};
6 changes: 3 additions & 3 deletions src/pipelines/ner.rs
@@ -15,11 +15,11 @@ use rust_tokenizers::bert_tokenizer::BertTokenizer;
 use std::path::Path;
 use tch::nn::VarStore;
 use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
-use crate::{BertForTokenClassification, BertConfig};
 use std::collections::HashMap;
 use crate::common::config::Config;
 use tch::{Tensor, no_grad, Device};
 use tch::kind::Kind::Float;
+use crate::bert::bert::{BertForTokenClassification, BertConfig};
 
 
 #[derive(Debug)]
@@ -67,8 +67,8 @@ impl NERModel {
         Tensor::stack(tokenized_input.as_slice(), 0).to(self.var_store.device())
     }
 
-    pub fn predict(&self, input: Vec<&str>) -> Vec<Entity> {
-        let input_tensor = self.prepare_for_model(input);
+    pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
+        let input_tensor = self.prepare_for_model(input.to_vec());
         let (output, _, _) = no_grad(|| {
             self.bert_sequence_classifier
                 .forward_t(Some(input_tensor.copy()),
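The signature change from `Vec<&str>` to `&[&str]` (here and in the sentiment pipeline below) lets callers keep ownership of their input and pass arrays, vectors, or slices alike. A small self-contained sketch of the pattern, independent of rust-bert; `predict` here is a toy stand-in, not the real pipeline:

```rust
// Borrowing a slice instead of consuming a Vec: the caller keeps its data
// and can call predict-style APIs repeatedly without cloning up front.
fn predict(input: &[&str]) -> Vec<usize> {
    // Toy stand-in for the real pipeline: return each input's length.
    input.iter().map(|s| s.len()).collect()
}

fn main() {
    let input = ["My name is Amy.", "I live in Amsterdam."];
    let first = predict(&input);  // arrays coerce to slices
    let second = predict(&input); // still usable: nothing was moved
    assert_eq!(first, second);

    let owned: Vec<&str> = input.to_vec();
    let third = predict(&owned);  // a Vec<&str> also derefs to &[&str]
    println!("{:?}", third);
}
```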
4 changes: 2 additions & 2 deletions src/pipelines/question_answering.rs
@@ -17,11 +17,11 @@ use std::path::{Path, PathBuf};
 use rust_tokenizers::tokenization_utils::truncate_sequences;
 use std::collections::HashMap;
 use std::cmp::min;
-use crate::{DistilBertForQuestionAnswering, DistilBertConfig};
 use tch::nn::VarStore;
-use crate::common::config::Config;
 use tch::kind::Kind::Float;
 use std::fs;
+use crate::distilbert::distilbert::{DistilBertForQuestionAnswering, DistilBertConfig};
+use crate::Config;
 
 pub struct QaInput {
     pub question: String,
6 changes: 3 additions & 3 deletions src/pipelines/sentiment.rs
@@ -12,12 +12,12 @@
 
 
 use rust_tokenizers::bert_tokenizer::BertTokenizer;
-use crate::distilbert::distilbert::{DistilBertModelClassifier, DistilBertConfig};
 use std::path::Path;
 use tch::{Device, Tensor, Kind, no_grad};
 use tch::nn::VarStore;
 use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
 use crate::common::config::Config;
+use crate::distilbert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
 
 
 #[derive(Debug, PartialEq)]
@@ -68,8 +68,8 @@ impl SentimentClassifier {
         Tensor::stack(tokenized_input.as_slice(), 0).to(self.var_store.device())
     }
 
-    pub fn predict(&self, input: Vec<&str>) -> Vec<Sentiment> {
-        let input_tensor = self.prepare_for_model(input);
+    pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
+        let input_tensor = self.prepare_for_model(input.to_vec());
         let (output, _, _) = no_grad(|| {
             self.distil_bert_classifier
                 .forward_t(Some(input_tensor),
6 changes: 4 additions & 2 deletions tests/bert.rs
@@ -4,7 +4,9 @@ extern crate dirs;
 use std::path::PathBuf;
 use tch::{Device, nn, Tensor, no_grad};
 use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
-use rust_bert::{NERModel, BertConfig, BertForMaskedLM, Config, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering};
+use rust_bert::bert::bert::{BertConfig, BertForMaskedLM, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering};
+use rust_bert::Config;
+use rust_bert::pipelines::ner::NERModel;
 
 #[test]
 fn bert_masked_lm() -> failure::Fallible<()> {
@@ -315,7 +317,7 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
     ];
 
     // Run model
-    let output = ner_model.predict(input.to_vec());
+    let output = ner_model.predict(&input);
 
 
     assert_eq!(output.len(), 4);
7 changes: 5 additions & 2 deletions tests/distilbert.rs
@@ -3,7 +3,10 @@ use tch::{Device, Tensor, nn, no_grad};
 use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
 use rust_tokenizers::bert_tokenizer::BertTokenizer;
 use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
-use rust_bert::{SentimentClassifier, SentimentPolarity, DistilBertConfig, DistilBertModelMaskedLM, Config, DistilBertForQuestionAnswering, DistilBertForTokenClassification, QuestionAnsweringModel, QaInput};
+use rust_bert::pipelines::sentiment::{SentimentClassifier, SentimentPolarity};
+use rust_bert::distilbert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertForQuestionAnswering, DistilBertForTokenClassification};
+use rust_bert::Config;
+use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
 
 extern crate failure;
 extern crate dirs;
@@ -32,7 +35,7 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
         "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
     ];
 
-    let output = sentiment_classifier.predict(input.to_vec());
+    let output = sentiment_classifier.predict(&input);
 
     assert_eq!(output.len(), 3 as usize);
     assert_eq!(output[0].polarity, SentimentPolarity::Positive);
3 changes: 2 additions & 1 deletion tests/distilgpt2.rs
@@ -1,7 +1,8 @@
 use std::path::PathBuf;
 use tch::{Device, nn, Tensor};
 use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
-use rust_bert::{Gpt2Config, GPT2LMHeadModel, Config, LMHeadModel};
+use rust_bert::gpt2::gpt2::{Gpt2Config, GPT2LMHeadModel, LMHeadModel};
+use rust_bert::Config;
 
 #[test]
 fn distilgpt2_lm_model() -> failure::Fallible<()> {