Skip to content

Commit

Permalink
Merge pull request #16 from leehosu01/hotfix/mixed_float16
Browse files Browse the repository at this point in the history
Update model.py, mixed_float16 feature
  • Loading branch information
rishigami committed Jan 23, 2022
2 parents 8986ca7 + 137ff4b commit bd4d9f1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion swintransformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def call(self, x, mask=None):
if mask is not None:
nW = mask.get_shape()[0] # tf.shape(mask)[0]
attn = tf.reshape(attn, shape=[-1, nW, self.num_heads, N, N]) + tf.cast(
tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32)
tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), attn.dtype)
attn = tf.reshape(attn, shape=[-1, self.num_heads, N, N])
attn = tf.nn.softmax(attn, axis=-1)
else:
Expand Down

0 comments on commit bd4d9f1

Please sign in to comment.