Update gather with sparse_grad set to false (#404)
* update `gather` with `sparse_grad` set to false

* Fixed Clippy warnings & updated changelog

* Pin ort version
guillaume-be committed Jul 16, 2023
1 parent 107fb21 commit a655b3c
Showing 11 changed files with 18 additions and 18 deletions.
CHANGELOG.md (1 addition, 0 deletions)
@@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. The format
 ## [Unreleased]
 ## Fixed
 - (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
+- Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations

 ## [0.21.0] - 2023-06-03
 ## Added
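For readers unfamiliar with the flag, here is a minimal, standalone sketch (not part of this commit, assuming the tch 0.13 API used by this crate) of where `sparse_grad` sits in `Tensor::gather`; the changelog entry above ties passing `false` to better MPS device compatibility, since gradients then stay dense:

use tch::{Device, Kind, Tensor};

fn main() {
    // input = [[0, 1], [2, 3]]
    let input = Tensor::arange(4, (Kind::Int64, Device::Cpu)).view([2, 2]);
    // Pick column 1 from row 0 and column 0 from row 1.
    let index = Tensor::of_slice(&[1i64, 0]).view([2, 1]);
    // The last argument is `sparse_grad`; this commit switches every call to `false`.
    let gathered = input.gather(1, &index, false);
    gathered.print(); // [[1], [2]]
}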
Cargo.toml (2 additions, 2 deletions)
@@ -87,7 +87,7 @@ regex = "1.6"
 cached-path = { version = "0.6", default-features = false, optional = true }
 dirs = { version = "4", optional = true }
 lazy_static = { version = "1", optional = true }
-ort = {version="1.14.8", optional = true, default-features = false, features = ["half"]}
+ort = {version="~1.14.8", optional = true, default-features = false, features = ["half"]}
 ndarray = {version="0.15", optional = true}

 [dev-dependencies]
@@ -99,4 +99,4 @@ torch-sys = "0.13.0"
 tempfile = "3"
 itertools = "0.10"
 tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
-ort = {version="1.14.8", features = ["load-dynamic"]}
+ort = {version="~1.14.8", features = ["load-dynamic"]}
src/models/bart/bart_model.rs (1 addition, 1 deletion)
@@ -357,7 +357,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
     let output = input_ids.empty_like().to_kind(Kind::Int64);
     output
         .select(1, 0)
-        .copy_(&input_ids.gather(1, &index_eos, true).squeeze());
+        .copy_(&input_ids.gather(1, &index_eos, false).squeeze());
     output
         .slice(1, 1, *output.size().last().unwrap(), 1)
         .copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
src/models/deberta/attention.rs (3 additions, 3 deletions)
@@ -192,7 +192,7 @@ impl DebertaDisentangledSelfAttention {
             let c2p_att = c2p_att.gather(
                 -1,
                 &self.c2p_dynamic_expand(&c2p_pos, query_layer, &relative_pos),
-                true,
+                false,
             );
             score = score + c2p_att;
         }
@@ -213,15 +213,15 @@
                 .gather(
                     -1,
                     &self.p2c_dynamic_expand(&p2c_pos, query_layer, key_layer),
-                    true,
+                    false,
                 )
                 .transpose(-1, -2);
             if query_layer_size[1] != key_layer_size[1] {
                 let pos_index = relative_pos.select(3, 0).unsqueeze(-1);
                 p2c_att = p2c_att.gather(
                     -2,
                     &self.pos_dynamic_expand(&pos_index, &p2c_att, key_layer),
-                    true,
+                    false,
                 );
             }
             score = score + p2c_att;
src/models/deberta_v2/attention.rs (3 additions, 3 deletions)
@@ -156,7 +156,7 @@ impl DebertaV2DisentangledSelfAttention {
                     ],
                     true,
                 ),
-                true,
+                false,
             );
             score = score + c2p_att / scale;
             Some(c2p_pos)
@@ -189,7 +189,7 @@
                     [query_layer.size()[0], key_layer_size[1], key_layer_size[1]],
                     true,
                 ),
-                true,
+                false,
             )
             .transpose(-1, -2);
             score = score + p2c_att / scale;
@@ -211,7 +211,7 @@
                     ],
                     true,
                 ),
-                true,
+                false,
             );
             score = score + p2p_att;
         }
src/models/distilbert/distilbert_model.rs (2 additions, 2 deletions)
@@ -192,8 +192,8 @@ impl DistilBertModel {
         P: Borrow<nn::Path<'p>>,
     {
         let p = p.borrow() / "distilbert";
-        let embeddings = DistilBertEmbedding::new(p.borrow() / "embeddings", config);
-        let transformer = Transformer::new(p.borrow() / "transformer", config);
+        let embeddings = DistilBertEmbedding::new(&p / "embeddings", config);
+        let transformer = Transformer::new(p / "transformer", config);
         DistilBertModel {
             embeddings,
             transformer,
src/models/mbart/mbart_model.rs (1 addition, 1 deletion)
@@ -162,7 +162,7 @@ fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
         - 1;
     output
         .select(1, 0)
-        .copy_(&input_ids.gather(1, &index_eos, true).squeeze());
+        .copy_(&input_ids.gather(1, &index_eos, false).squeeze());
     output
         .slice(1, 1, *output.size().last().unwrap(), 1)
         .copy_(&input_ids.slice(1, 0, *output.size().last().unwrap() - 1, 1));
src/models/prophetnet/attention.rs (1 addition, 1 deletion)
@@ -712,7 +712,7 @@ impl ProphetNetNgramAttention {
         ]);

         rel_pos_embeddings
-            .gather(1, &predict_relative_position_buckets, true)
+            .gather(1, &predict_relative_position_buckets, false)
             .view([
                 self.ngram,
                 batch_size * self.num_attention_heads,
src/models/reformer/attention_utils.rs (2 additions, 2 deletions)
@@ -173,8 +173,8 @@ pub fn reverse_sort(
     let expanded_undo_sort_indices = undo_sorted_bucket_idx
         .unsqueeze(-1)
         .expand(out_vectors.size().as_slice(), true);
-    let out_vectors = out_vectors.gather(2, &expanded_undo_sort_indices, true);
-    let logits = logits.gather(2, undo_sorted_bucket_idx, true);
+    let out_vectors = out_vectors.gather(2, &expanded_undo_sort_indices, false);
+    let logits = logits.gather(2, undo_sorted_bucket_idx, false);
     (out_vectors, logits)
 }

src/pipelines/generation_utils.rs (1 addition, 1 deletion)
@@ -964,7 +964,7 @@ pub(crate) mod private_generation_utils {
                 prev_scores.push(
                     next_token_logits
                         .log_softmax(-1, next_token_logits.kind())
-                        .gather(1, &next_token.reshape([-1, 1]), true)
+                        .gather(1, &next_token.reshape([-1, 1]), false)
                         .squeeze()
                         .masked_fill(&finished_mask, 0),
                 );
src/pipelines/zero_shot_classification.rs (1 addition, 2 deletions)
@@ -121,7 +121,6 @@ use crate::{
     bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
     resources::RemoteResource,
 };
-use std::ops::Deref;
 use tch::kind::Kind::{Bool, Float};
 use tch::nn::VarStore;
 use tch::{no_grad, Device, Kind, Tensor};
@@ -698,7 +697,7 @@ impl ZeroShotClassificationModel {
             .flat_map(|input| {
                 label_sentences
                     .iter()
-                    .map(move |label_sentence| (input.deref(), label_sentence.as_str()))
+                    .map(move |label_sentence| (*input, label_sentence.as_str()))
             })
             .collect::<Vec<(&str, &str)>>();

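One note on this last file: dropping the now-unused `use std::ops::Deref;` import and writing `*input` instead of `input.deref()` is the kind of change suggested by Clippy's `explicit_deref_methods` lint, which matches the "Fixed Clippy warnings" bullet in the commit message.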
