Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

fix constraint bug in beam search, clean up tests #5328

Merged
merged 4 commits into from
Aug 3, 2021
Merged

Conversation

epwalsh
Copy link
Member

@epwalsh epwalsh commented Jul 22, 2021

Comment on lines 763 to 773
step_function = get_step_function(repeated_ngram_transition_probabilities_1)
self.beam_search.max_steps = 5
expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]])
expected_log_probs = np.log(
np.array([0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1])
)
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=step_function,
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part fails before the fix.

Copy link
Member Author

@epwalsh epwalsh Jul 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest of the changes to this file are just refactoring.

Comment on lines +1088 to +1090
backpointer = torch.divide(
restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc"
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This avoids a deprecation warning.

@epwalsh epwalsh requested a review from dirkgr July 22, 2021 20:15
Comment on lines +85 to +87
for last_token in last_predictions:
log_probs = torch.log(transition_matrix[last_token.item()])
log_probs_list.append(log_probs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me that there should be a way to do this without a for loop? Not that it really matters, but something like Tensor.index_select(...)?

@epwalsh epwalsh merged commit b72bbfc into main Aug 3, 2021
@epwalsh epwalsh deleted the beam-search-fixes branch August 3, 2021 16:55
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants