Skip to content

Commit

Permalink
don't use subtraction of boolean tensors, use xor (#420)
Browse files Browse the repository at this point in the history
* don't use subtraction of boolean tensors, use xor

This code fails to run without the change

* rustfmt

---------

Co-authored-by: Charles Samuels <ks@ks.ax>
Co-authored-by: guillaume-be <guillaume.becquin@gmail.com>
  • Loading branch information
3 people committed Sep 30, 2023
1 parent 9575902 commit 98c7905
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/models/bert/bert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,11 @@ impl<T: BertEmbedding> BertModel<T> {
train,
)?;

let extended_attention_mask: Tensor =
((extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0)
.to_kind(embedding_output.kind());
let extended_attention_mask: Tensor = ((extended_attention_mask
.ones_like()
.bitwise_xor_tensor(&extended_attention_mask))
* -10000.0)
.to_kind(embedding_output.kind());

let encoder_extended_attention_mask: Option<Tensor> =
if self.is_decoder & encoder_hidden_states.is_some() {
Expand Down

0 comments on commit 98c7905

Please sign in to comment.