Skip to content

Commit

Permalink
convert relative_position_index to int64
Browse files Browse the repository at this point in the history
Convert dtype of relative_position_index to np.int64
  • Loading branch information
rishigami committed Aug 16, 2021
1 parent f9d7120 commit 47307c3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion swintransformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def build(self, input_shape):
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
relative_position_index = relative_coords.sum(-1).astype(np.int64)
self.relative_position_index = tf.Variable(initial_value=tf.convert_to_tensor(
relative_position_index), trainable=False, name=f'{self.prefix}/attn/relative_position_index')
self.built = True
Expand Down

0 comments on commit 47307c3

Please sign in to comment.