This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit bffdbfd

Bugfix: initializing all tensors and parameters of the `ConditionalRandomField` model on the proper device (#5335)

* Bugfix: initializing all tensors and parameters of the `ConditionalRandomField` on the proper device

* Using `torch.full()` instead of `torch.Tensor().fill_()`

* Updated CHANGELOG.md

Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>
Co-authored-by: Pete <petew@allenai.org>
Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
4 people authored Aug 18, 2021
1 parent d45a2da commit bffdbfd
Showing 2 changed files with 6 additions and 5 deletions.
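
Before the diffs, a minimal sketch (not part of the commit) of the API change behind the second bullet in the commit message above: `torch.Tensor(...)` always allocates uninitialized storage on the CPU, while `torch.full()` takes an explicit fill value and an optional `device` argument, so the result can be placed next to the tensors it will be combined with. `num_tags` here is a made-up size for illustration:

    import torch

    num_tags = 5  # hypothetical size, for illustration only
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Old pattern: allocates on the CPU and fills in place,
    # regardless of where the rest of the model lives.
    old = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.0)
    print(old.device)  # always cpu

    # New pattern: one call, with the target device named explicitly.
    new = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=device)
    print(new.device)  # matches `device`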
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

+- Fixed a bug in `ConditionalRandomField`: `transitions` and `tag_sequence` tensors were not initialized on the desired device, causing high CPU usage (see https://github.com/allenai/allennlp/issues/2884)
 - Fixed a misspelling: the parameter `contructor_extras` in `Lazy()` is now correctly called `constructor_extras`.
 - Fixed broken links in `allennlp.nn.initializers` docs.
 - Fixed bug in `BeamSearch` where `last_backpointers` was not being passed to any `Constraint`s.
10 changes: 5 additions & 5 deletions allennlp/modules/conditional_random_field.py
@@ -186,15 +186,15 @@ def __init__(
         self.num_tags = num_tags

         # transitions[i, j] is the logit for transitioning from state i to state j.
-        self.transitions = torch.nn.Parameter(torch.Tensor(num_tags, num_tags))
+        self.transitions = torch.nn.Parameter(torch.empty(num_tags, num_tags))

         # _constraint_mask indicates valid transitions (based on supplied constraints).
         # Include special start of sequence (num_tags + 1) and end of sequence tags (num_tags + 2)
         if constraints is None:
             # All transitions are valid.
-            constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(1.0)
+            constraint_mask = torch.full((num_tags + 2, num_tags + 2), 1.0)
         else:
-            constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(0.0)
+            constraint_mask = torch.full((num_tags + 2, num_tags + 2), 0.0)
             for i, j in constraints:
                 constraint_mask[i, j] = 1.0
@@ -364,7 +364,7 @@ def viterbi_tags(
         # Augment transitions matrix with start and end transitions
         start_tag = num_tags
         end_tag = num_tags + 1
-        transitions = torch.Tensor(num_tags + 2, num_tags + 2).fill_(-10000.0)
+        transitions = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=logits.device)

         # Apply transition constraints
         constrained_transitions = self.transitions * self._constraint_mask[
@@ -393,7 +393,7 @@

best_paths = []
# Pad the max sequence length by 2 to account for start_tag + end_tag.
tag_sequence = torch.Tensor(max_seq_length + 2, num_tags + 2)
tag_sequence = torch.empty(max_seq_length + 2, num_tags + 2, device=logits.device)

for prediction, prediction_mask in zip(logits, mask):
mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze()
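
A hedged usage sketch of the fixed decode path, assuming the allennlp 2.x API (tag count, shapes, and batch size are made up for illustration): with this change, the augmented `transitions` matrix and the `tag_sequence` scratch tensor are created on `logits.device`, so Viterbi decoding no longer falls back to CPU tensors when the model runs on a GPU.

    import torch
    from allennlp.modules.conditional_random_field import ConditionalRandomField

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    crf = ConditionalRandomField(num_tags=5).to(device)

    logits = torch.randn(2, 7, 5, device=device)              # (batch, seq_len, num_tags)
    mask = torch.ones(2, 7, dtype=torch.bool, device=device)  # (batch, seq_len)

    # Each entry is a (tag_sequence, viterbi_score) pair for one batch element.
    best_paths = crf.viterbi_tags(logits, mask)
    print(best_paths)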
