fix constraint bug in beam search, clean up tests #5328
Conversation
tests/nn/beam_search_test.py
step_function = get_step_function(repeated_ngram_transition_probabilities_1)
self.beam_search.max_steps = 5
expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]])
expected_log_probs = np.log(
    np.array([0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1])
)
self._check_results(
    expected_top_k=expected_top_k,
    expected_log_probs=expected_log_probs,
    take_step=step_function,
)
This part fails before the fix.
The rest of the changes to this file are just refactoring.
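For context on the expected scores in the test above: beam search ranks sequences by summed per-step log probabilities, which is the same as the log of the product of per-step transition probabilities. A minimal sketch of that identity (the array values come from the test; the variable names here are illustrative):

```python
import numpy as np

# Per-step transition probabilities of the two expected sequences above.
seq_probs = np.array([
    [0.4, 0.3, 0.3, 0.2, 0.1],  # predictions [1, 2, 3, 4, 5]
    [0.4, 0.3, 0.2, 0.3, 0.1],  # predictions [1, 2, 4, 3, 5]
])

# Summing log probabilities step by step (what beam search accumulates)
# equals taking the log of the product of per-step probabilities.
summed_logs = np.log(seq_probs).sum(axis=1)
log_of_products = np.log(seq_probs.prod(axis=1))
assert np.allclose(summed_logs, log_of_products)
```

Note that the two sequences reorder the same factors, so their products (and log probabilities) are equal; the test is exercising how ties and repeated n-grams are handled, not a strict score ordering.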
backpointer = torch.divide(
    restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc"
)
This avoids a deprecation warning.
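For anyone unfamiliar with the warning: integer floor division on tensors (`//` / `torch.floor_divide`) was deprecated in newer PyTorch releases, and `rounding_mode="trunc"` gives the same backpointer arithmetic for non-negative indices without the warning. A small sketch with made-up indices (the values below are illustrative, not from the PR):

```python
import torch

# Flat candidate indices into a (beam_size * per_node_beam_size) array.
beam_indices = torch.tensor([0, 3, 7, 11])
per_node_beam_size = 4

# Integer division recovers which parent beam each candidate came from.
# torch.divide(..., rounding_mode="trunc") replaces the deprecated
# floor-division behavior for these non-negative indices.
backpointer = torch.divide(beam_indices, per_node_beam_size, rounding_mode="trunc")
# backpointer is tensor([0, 0, 1, 2])
```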
for last_token in last_predictions:
    log_probs = torch.log(transition_matrix[last_token.item()])
    log_probs_list.append(log_probs)
It seems to me that there should be a way to do this without a for loop? Not that it really matters, but something like `Tensor.index_select(...)`?
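A sketch of the suggested vectorization, using placeholder shapes (the tensor sizes here are made up for illustration): `index_select` along dim 0 gathers all the needed rows of the transition matrix at once, so the loop and the per-row `.item()` calls go away.

```python
import torch

torch.manual_seed(0)
transition_matrix = torch.rand(5, 5)          # illustrative shape
last_predictions = torch.tensor([0, 2, 4])    # illustrative token ids

# Loop version, as in the diff above.
log_probs_list = []
for last_token in last_predictions:
    log_probs_list.append(torch.log(transition_matrix[last_token.item()]))
looped = torch.stack(log_probs_list)

# Vectorized alternative: gather all rows in one call, then take the log.
vectorized = torch.log(transition_matrix.index_select(0, last_predictions))

assert torch.allclose(looped, vectorized)
```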
See #5216 (comment) for context. @JohnGiorgi @danieldeutsch