Update tch to 0.15 #443

Merged: 4 commits, Feb 11, 2024
CHANGELOG.md: 5 additions & 2 deletions
@@ -3,14 +3,17 @@ All notable changes to this project will be documented in this file. The format

## [Unreleased]

+## Changed
+- (BREAKING) Upgraded to `torch` 2.2 (via `tch` 0.15.0).

## [0.22.0] - 2024-01-20
## Added
- Addition of a `new_with_tokenizer` constructor for `SentenceEmbeddingsModel`, allowing custom tokenizers to be passed to sentence embeddings pipelines.
- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files.
- (BREAKING) Most model configuration can now take an optional `kind` parameter to specify the model weight precision. If not provided, will default to full precision on CPU, or the serialized weights precision otherwise.

## Fixed
-- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences).
+- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
- Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations
- Updated ONNX runtime backend version to 1.15.x
- Issue with incorrect results for QA models with a tokenizer not using segment ids
@@ -449,4 +452,4 @@ All notable changes to this project will be documented in this file. The format

- Tensor conversion tools from Pytorch to Libtorch format
- DistilBERT model architecture
-- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2
+- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2
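The MPS entry in the Fixed section above refers to the `sparse_grad` argument of tch's `gather`. A minimal sketch of the call shape it describes, assuming the `tch::Tensor::gather(dim, index, sparse_grad)` signature (illustrative, not code from the repository):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Gather values along dim 1 with sparse_grad = false, the setting the
    // changelog entry above says improves MPS device compatibility.
    let values = Tensor::rand([2, 4], (Kind::Float, Device::Cpu));
    let index = Tensor::from_slice(&[0i64, 3, 1, 2]).reshape([2, 2]);
    let gathered = values.gather(1, &index, /* sparse_grad = */ false);
    gathered.print();
}
```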
Cargo.toml: 1 addition & 2 deletions
@@ -76,7 +76,7 @@ features = ["doc-only"]

[dependencies]
rust_tokenizers = "8.1.1"
tch = "0.14.0"
tch = "0.15.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
@@ -97,7 +97,6 @@ anyhow = "1"
csv = "1"
criterion = "0.5"
tokio = { version = "1.35", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "0.14.0"
tempfile = "3"
itertools = "0.12"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
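With `tch` bumped to 0.15 and the `torch-sys` dev-dependency dropped, downstream code only needs `tch` itself to exercise the new libtorch 2.2 binding. A minimal smoke-test sketch (illustrative, not part of this PR):

```rust
use tch::{Cuda, Device, Kind, Tensor};

fn main() {
    // Sanity check that the tch 0.15 / libtorch 2.2 pairing links and runs;
    // CUDA is used automatically when the local libtorch build provides it.
    let device = Device::cuda_if_available();
    println!("CUDA available: {}", Cuda::is_available());

    let x = Tensor::rand([2, 3], (Kind::Float, device));
    let y = Tensor::rand([3, 2], (Kind::Float, device));
    println!("matmul result shape: {:?}", x.matmul(&y).size()); // expected: [2, 2]
}
```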
README.md: 2 additions & 2 deletions
@@ -80,8 +80,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett

### Manual installation (recommended)

-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.1`: if this version is no longer available on the "get started" page,
-the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.2`: if this version is no longer available on the "get started" page,
+the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:
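The installation steps above rely on the `LIBTORCH` environment variable pointing at the extracted archive. A small, illustrative pre-build check (hypothetical helper, not part of rust-bert):

```rust
use std::env;
use std::path::Path;

fn main() {
    // Verify that LIBTORCH points at an extracted libtorch directory,
    // which should contain a `lib` subdirectory with the shared libraries.
    match env::var("LIBTORCH") {
        Ok(dir) if Path::new(&dir).join("lib").is_dir() => {
            println!("LIBTORCH looks usable: {dir}");
        }
        Ok(dir) => eprintln!("LIBTORCH is set to {dir}, but no `lib` directory was found there"),
        Err(_) => eprintln!("LIBTORCH is not set; see the installation steps above"),
    }
}
```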
benches/generation_benchmark.rs: 4 deletions
@@ -53,10 +53,6 @@ fn generation_forward_pass(iters: u64, model: &TextGenerationModel, data: &[&str
}

fn bench_generation(c: &mut Criterion) {
-// Set-up summarization model
-unsafe {
-torch_sys::dummy_cuda_dependency();
-}
let model = create_text_generation_model();

// Define input
benches/squad_benchmark.rs: 1 addition & 3 deletions
@@ -73,9 +73,7 @@ fn qa_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up QA model
let model = create_qa_model();
-unsafe {
-torch_sys::dummy_cuda_dependency();
-}

// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
benches/sst2_benchmark.rs: 1 addition & 3 deletions
@@ -79,9 +79,7 @@ fn sst2_load_model(iters: u64) -> Duration {
fn bench_sst2(c: &mut Criterion) {
// Set-up classifier
let model = create_sentiment_model();
-unsafe {
-torch_sys::dummy_cuda_dependency();
-}

// Define input
let mut sst2_path = PathBuf::from(env::var("SST2_PATH").expect(
"Please set the \"SST2_PATH\" environment variable pointing to the SST2 dataset folder",
benches/summarization_benchmark.rs: 3 deletions
@@ -40,9 +40,6 @@ fn summarization_load_model(iters: u64) -> Duration {

fn bench_squad(c: &mut Criterion) {
// Set-up summarization model
-unsafe {
-torch_sys::dummy_cuda_dependency();
-}
let model = create_summarization_model();

// Define input
benches/tensor_operations_benchmark.rs: 4 deletions
@@ -17,10 +17,6 @@ fn matrix_multiply(iters: u64, input: &Tensor, weights: &Tensor) -> Duration {
}

fn bench_tensor_ops(c: &mut Criterion) {
-// Set-up summarization model
-unsafe {
-torch_sys::dummy_cuda_dependency();
-}
let input = Tensor::rand([32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand([512, 512], (Kind::Float, Device::cuda_if_available()));

benches/token_classification_benchmark.rs: 3 deletions
@@ -14,9 +14,6 @@ fn create_model() -> TokenClassificationModel {

fn bench_token_classification_predict(c: &mut Criterion) {
// Set-up model
-unsafe {
-torch_sys::dummy_cuda_dependency();
-}
let model = create_model();

// Define input
benches/translation_benchmark.rs: 3 deletions
@@ -73,9 +73,6 @@ fn translation_load_model(iters: u64) -> Duration {

fn bench_squad(c: &mut Criterion) {
// Set-up translation model
-unsafe {
-torch_sys::dummy_cuda_dependency();
-}
let model = create_translation_model();

// Define input
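Each benchmark above drops the same unsafe `torch_sys::dummy_cuda_dependency()` workaround. A minimal sketch of what a benchmark set-up looks like after this change, assuming device selection through `tch` alone now pulls in CUDA when available (illustrative shape, not a file from this PR):

```rust
use criterion::{criterion_group, criterion_main, Criterion};
use tch::{Device, Kind, Tensor};

fn bench_matmul(c: &mut Criterion) {
    // No `unsafe { torch_sys::dummy_cuda_dependency(); }` call any more:
    // tch picks the CUDA device on its own when the local libtorch supports it.
    let device = Device::cuda_if_available();
    let input = Tensor::rand([32, 128, 512], (Kind::Float, device));
    let weights = Tensor::rand([512, 512], (Kind::Float, device));
    c.bench_function("matmul", |b| b.iter(|| input.matmul(&weights)));
}

criterion_group!(benches, bench_matmul);
criterion_main!(benches);
```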
src/lib.rs: 2 additions & 2 deletions
@@ -90,8 +90,8 @@
//!
//! ### Manual installation (recommended)
//!
-//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.1`: if this version is no longer available on the "get started" page,
-//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11.
+//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.2`: if this version is no longer available on the "get started" page,
+//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux: