From a655b3c7c40eb3a6bc1db66e85fbda5c63514ec8 Mon Sep 17 00:00:00 2001
From: guillaume-be
Date: Sun, 16 Jul 2023 09:46:46 +0100
Subject: [PATCH] Update `gather` with `sparse_grad` set to false (#404)

* update `gather` with `sparse_grad` set to false

* Fixed Clippy warnings & updated changelog

* Pin ort version
---
 CHANGELOG.md                              | 1 +
 Cargo.toml                                | 4 ++--
 src/models/bart/bart_model.rs             | 2 +-
 src/models/deberta/attention.rs           | 6 +++---
 src/models/deberta_v2/attention.rs        | 6 +++---
 src/models/distilbert/distilbert_model.rs | 4 ++--
 src/models/mbart/mbart_model.rs           | 2 +-
 src/models/prophetnet/attention.rs        | 2 +-
 src/models/reformer/attention_utils.rs    | 4 ++--
 src/pipelines/generation_utils.rs         | 2 +-
 src/pipelines/zero_shot_classification.rs | 3 +--
 11 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 62de73e51..bcb8a7098 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. The format
 ## [Unreleased]
 ## Fixed
 - (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences).
+- Improved MPS device compatibility by setting the `sparse_grad` flag to `false` for `gather` operations
 
 ## [0.21.0] - 2023-06-03
 ## Added
diff --git a/Cargo.toml b/Cargo.toml
index fd61596b5..71520c4b9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -87,7 +87,7 @@ regex = "1.6"
 cached-path = { version = "0.6", default-features = false, optional = true }
 dirs = { version = "4", optional = true }
 lazy_static = { version = "1", optional = true }
-ort = {version="1.14.8", optional = true, default-features = false, features = ["half"]}
+ort = {version="~1.14.8", optional = true, default-features = false, features = ["half"]}
 ndarray = {version="0.15", optional = true}
 
 [dev-dependencies]
@@ -99,4 +99,4 @@ torch-sys = "0.13.0"
 tempfile = "3"
 itertools = "0.10"
 tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
-ort = {version="1.14.8", features = ["load-dynamic"]}
\ No newline at end of file
+ort = {version="~1.14.8", features = ["load-dynamic"]}
\ No newline at end of file
diff --git a/src/models/bart/bart_model.rs b/src/models/bart/bart_model.rs
index 64dc89594..62246fac3 100644
--- a/src/models/bart/bart_model.rs
+++ b/src/models/bart/bart_model.rs
@@ -357,7 +357,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
     let output = input_ids.empty_like().to_kind(Kind::Int64);
     output
         .select(1, 0)
-        .copy_(&input_ids.gather(1, &index_eos, true).squeeze());
+        .copy_(&input_ids.gather(1, &index_eos, false).squeeze());
     output
         .slice(1, 1, *output.size().last().unwrap(), 1)
         .copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
diff --git a/src/models/deberta/attention.rs b/src/models/deberta/attention.rs
index 948b01fde..0422d6c92 100644
--- a/src/models/deberta/attention.rs
+++ b/src/models/deberta/attention.rs
@@ -192,7 +192,7 @@ impl DebertaDisentangledSelfAttention {
             let c2p_att = c2p_att.gather(
                 -1,
                 &self.c2p_dynamic_expand(&c2p_pos, query_layer, &relative_pos),
-                true,
+                false,
             );
             score = score + c2p_att;
         }
@@ -213,7 +213,7 @@ impl DebertaDisentangledSelfAttention {
                 .gather(
                     -1,
                     &self.p2c_dynamic_expand(&p2c_pos, query_layer, key_layer),
-                    true,
+                    false,
                 )
                 .transpose(-1, -2);
             if query_layer_size[1] != key_layer_size[1] {
@@ -221,7 +221,7 @@ impl DebertaDisentangledSelfAttention {
                p2c_att = p2c_att.gather(
                    -2,
                    &self.pos_dynamic_expand(&pos_index, &p2c_att, key_layer),
-                    true,
+                    false,
                );
             }
             score = score + p2c_att;
diff --git a/src/models/deberta_v2/attention.rs b/src/models/deberta_v2/attention.rs
index 4cc431f96..af1e168cc 100644
--- a/src/models/deberta_v2/attention.rs
+++ b/src/models/deberta_v2/attention.rs
@@ -156,7 +156,7 @@ impl DebertaV2DisentangledSelfAttention {
                     ],
                     true,
                 ),
-                true,
+                false,
             );
             score = score + c2p_att / scale;
             Some(c2p_pos)
@@ -189,7 +189,7 @@ impl DebertaV2DisentangledSelfAttention {
                     [query_layer.size()[0], key_layer_size[1], key_layer_size[1]],
                     true,
                 ),
-                true,
+                false,
             )
             .transpose(-1, -2);
             score = score + p2c_att / scale;
@@ -211,7 +211,7 @@ impl DebertaV2DisentangledSelfAttention {
                     ],
                     true,
                 ),
-                true,
+                false,
             );
             score = score + p2p_att;
         }
diff --git a/src/models/distilbert/distilbert_model.rs b/src/models/distilbert/distilbert_model.rs
index daad2c84f..20ccbf40f 100644
--- a/src/models/distilbert/distilbert_model.rs
+++ b/src/models/distilbert/distilbert_model.rs
@@ -192,8 +192,8 @@ impl DistilBertModel {
         P: Borrow<nn::Path<'p>>,
     {
         let p = p.borrow() / "distilbert";
-        let embeddings = DistilBertEmbedding::new(p.borrow() / "embeddings", config);
-        let transformer = Transformer::new(p.borrow() / "transformer", config);
+        let embeddings = DistilBertEmbedding::new(&p / "embeddings", config);
+        let transformer = Transformer::new(p / "transformer", config);
         DistilBertModel {
             embeddings,
             transformer,
diff --git a/src/models/mbart/mbart_model.rs b/src/models/mbart/mbart_model.rs
index e9eaeffb6..5e8a01623 100644
--- a/src/models/mbart/mbart_model.rs
+++ b/src/models/mbart/mbart_model.rs
@@ -162,7 +162,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
         - 1;
     output
         .select(1, 0)
-        .copy_(&input_ids.gather(1, &index_eos, true).squeeze());
+        .copy_(&input_ids.gather(1, &index_eos, false).squeeze());
     output
         .slice(1, 1, *output.size().last().unwrap(), 1)
         .copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
diff --git a/src/models/prophetnet/attention.rs b/src/models/prophetnet/attention.rs
index d381129b0..ece63dfe2 100644
--- a/src/models/prophetnet/attention.rs
+++ b/src/models/prophetnet/attention.rs
@@ -712,7 +712,7 @@ impl ProphetNetNgramAttention {
         ]);
 
         rel_pos_embeddings
-            .gather(1, &predict_relative_position_buckets, true)
+            .gather(1, &predict_relative_position_buckets, false)
             .view([
                 self.ngram,
                 batch_size * self.num_attention_heads,
diff --git a/src/models/reformer/attention_utils.rs b/src/models/reformer/attention_utils.rs
index f5101b39b..3ec0b8bf3 100644
--- a/src/models/reformer/attention_utils.rs
+++ b/src/models/reformer/attention_utils.rs
@@ -173,8 +173,8 @@ pub fn reverse_sort(
     let expanded_undo_sort_indices = undo_sorted_bucket_idx
         .unsqueeze(-1)
         .expand(out_vectors.size().as_slice(), true);
-    let out_vectors = out_vectors.gather(2, &expanded_undo_sort_indices, true);
-    let logits = logits.gather(2, undo_sorted_bucket_idx, true);
+    let out_vectors = out_vectors.gather(2, &expanded_undo_sort_indices, false);
+    let logits = logits.gather(2, undo_sorted_bucket_idx, false);
     (out_vectors, logits)
 }
 
diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs
index 8d901e85f..8869dec17 100644
--- a/src/pipelines/generation_utils.rs
+++ b/src/pipelines/generation_utils.rs
@@ -964,7 +964,7 @@ pub(crate) mod private_generation_utils {
                     prev_scores.push(
                         next_token_logits
                             .log_softmax(-1, next_token_logits.kind())
-                            .gather(1, &next_token.reshape([-1, 1]), true)
+                            .gather(1, &next_token.reshape([-1, 1]), false)
                             .squeeze()
                             .masked_fill(&finished_mask, 0),
                     );
diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs
index c81b788ef..ea0cc327a 100644
--- a/src/pipelines/zero_shot_classification.rs
+++ b/src/pipelines/zero_shot_classification.rs
@@ -121,7 +121,6 @@ use crate::{
     bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
     resources::RemoteResource,
 };
-use std::ops::Deref;
 use tch::kind::Kind::{Bool, Float};
 use tch::nn::VarStore;
 use tch::{no_grad, Device, Kind, Tensor};
@@ -698,7 +697,7 @@ impl ZeroShotClassificationModel {
             .flat_map(|input| {
                 label_sentences
                     .iter()
-                    .map(move |label_sentence| (input.deref(), label_sentence.as_str()))
+                    .map(move |label_sentence| (*input, label_sentence.as_str()))
             })
             .collect::<Vec<(&str, &str)>>();
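
Reviewer note on the two changes above: every touched call site flips the third argument of
`tch::Tensor::gather` (`sparse_grad`) from `true` to `false`, and the `ort` dependency moves
from a caret to a tilde requirement (`~1.14.8` accepts `>=1.14.8, <1.15.0`, i.e. patch updates
only). A minimal sketch of the `gather` call shape follows; the tensor values are illustrative
only, and it assumes a `tch` release where the `Tensor::from_slice` constructor is available:

    use tch::Tensor;

    fn main() {
        // A 2x3 input; `index` selects one column per row along dim 1.
        let input = Tensor::from_slice(&[1i64, 2, 3, 4, 5, 6]).reshape([2, 3]);
        let index = Tensor::from_slice(&[2i64, 0]).reshape([2, 1]);

        // The third argument is `sparse_grad`: `false` requests a dense
        // gradient for the backward pass, which is the setting this patch
        // applies at every call site. The forward result is identical for
        // either value.
        let gathered = input.gather(1, &index, false);
        gathered.print(); // [[3], [4]]
    }

Since `sparse_grad` only selects which backward kernel autograd would use, the forward outputs
of these gathers are unchanged; the patch simply avoids the sparse-gradient path, which is the
source of the MPS device incompatibility noted in the changelog entry.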