This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
fix constraint bug in beam search, clean up tests #5328
Merged
Merged
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import Dict, Tuple | ||
from typing import Dict, Tuple, Union | ||
|
||
import numpy as np | ||
import pytest | ||
|
@@ -12,120 +12,98 @@ | |
TopKSampler, | ||
TopPSampler, | ||
GumbelSampler, | ||
SequenceLogProbabilityScorer, | ||
LengthNormalizedSequenceLogProbabilityScorer, | ||
RepeatedNGramBlockingConstraint, | ||
StepFunctionTypeWithTimestep, | ||
StepFunctionTypeNoTimestep, | ||
) | ||
from allennlp.common.params import Params | ||
from allennlp.nn.util import min_value_of_dtype | ||
|
||
|
||
# fmt: off | ||
transition_probabilities = torch.tensor( | ||
[ | ||
[0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # start token -> jth token | ||
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # 1st token -> jth token | ||
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0], # 2nd token -> jth token | ||
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0], # ... | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # ... | ||
[0.2, 0.1, 0.2, 0.2, 0.2, 0.3], | ||
] # end token -> jth token | ||
[ # START 1 2 3 4 END | ||
[0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # START -> j | ||
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # 1 -> j | ||
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0], # 2 -> j | ||
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0], # 3 -> j | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # 4 -> j | ||
[0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter) | ||
] | ||
) | ||
|
||
# A transition matrix that favors shorter sequences over longer ones | ||
short_sequence_transition_probabilities = torch.tensor( | ||
[ | ||
[0.0, 0.1, 0.0, 0.0, 0.0, 0.9], # start token -> jth token | ||
[0.0, 0.0, 0.1, 0.0, 0.0, 0.9], # 1st token -> jth token | ||
[0.0, 0.0, 0.0, 0.1, 0.0, 0.9], # 2nd token -> jth token | ||
[0.0, 0.0, 0.0, 0.0, 0.1, 0.9], # ... | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # ... | ||
[0.2, 0.1, 0.2, 0.2, 0.2, 0.3], | ||
] # end token -> jth token | ||
[ # START 1 2 3 4 END | ||
[0.0, 0.1, 0.0, 0.0, 0.0, 0.9], # START -> j | ||
[0.0, 0.0, 0.1, 0.0, 0.0, 0.9], # 1 -> j | ||
[0.0, 0.0, 0.0, 0.1, 0.0, 0.9], # 2 -> j | ||
[0.0, 0.0, 0.0, 0.0, 0.1, 0.9], # 3 -> j | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # 4 -> j | ||
[0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter) | ||
] | ||
) | ||
|
||
# A transition matrix that favors repeated ngrams | ||
repeated_ngram_transition_probabilities = torch.tensor( | ||
[ | ||
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # start token -> jth token | ||
[0.0, 0.0, 0.4, 0.6, 0.0, 1e-9], # 1st token -> jth token | ||
[0.0, 0.0, 0.0, 1.0, 0.0, 1e-9], # 2nd token -> jth token | ||
[0.0, 1.0, 0.0, 0.0, 0.0, 1e-9], # ... | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # not used | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], | ||
] # end token -> jth token | ||
repeated_ngram_transition_probabilities_0 = torch.tensor( | ||
[ # START 1 2 3 4 END | ||
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # START -> j | ||
[0.0, 0.0, 0.4, 0.6, 0.0, 1e-9], # 1 -> j | ||
[0.0, 0.0, 0.0, 1.0, 0.0, 1e-9], # 2 -> j | ||
[0.0, 1.0, 0.0, 0.0, 0.0, 1e-9], # 3 -> j | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # 4 -> j (not used) | ||
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # END -> j (doesn't matter) | ||
] | ||
) | ||
|
||
# Another transition matrix that favors repeated ngrams | ||
repeated_ngram_transition_probabilities_1 = torch.tensor( | ||
[ # START 1 2 3 4 END | ||
[0.0, 0.4, 0.3, 0.2, 0.1, 0.0], # START -> j | ||
[0.0, 0.4, 0.3, 0.2, 0.1, 0.1], # 1 -> j | ||
[0.0, 0.0, 0.4, 0.3, 0.2, 0.1], # 2 -> j | ||
[0.0, 0.0, 0.3, 0.4, 0.2, 0.1], # 3 -> j | ||
[0.0, 0.0, 0.2, 0.3, 0.4, 0.1], # 4 -> j | ||
[0.2, 0.1, 0.2, 0.2, 0.2, 0.1], # END -> j (doesn't matter) | ||
] | ||
) | ||
# fmt: on | ||
|
||
log_probabilities = torch.log( | ||
torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]]) | ||
) | ||
|
||
|
||
def take_step_no_timestep( | ||
last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] | ||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
""" | ||
Take decoding step. | ||
|
||
This is a simple function that defines how probabilities are computed for the | ||
next time step during the beam search. | ||
|
||
We use a simple target vocabulary of size 6. In this vocabulary, index 0 represents | ||
the start token, and index 5 represents the end token. The transition probability | ||
from a state where the last predicted token was token `j` to new token `i` is | ||
given by the `(i, j)` element of the matrix `transition_probabilities`. | ||
""" | ||
log_probs_list = [] | ||
for last_token in last_predictions: | ||
log_probs = torch.log(transition_probabilities[last_token.item()]) | ||
log_probs_list.append(log_probs) | ||
|
||
return torch.stack(log_probs_list), state | ||
|
||
def get_step_function( | ||
transition_matrix: torch.Tensor, with_timestep: bool = False | ||
) -> Union[StepFunctionTypeNoTimestep, StepFunctionTypeWithTimestep]: | ||
def _step_function( | ||
last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] | ||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
log_probs_list = [] | ||
for last_token in last_predictions: | ||
log_probs = torch.log(transition_matrix[last_token.item()]) | ||
log_probs_list.append(log_probs) | ||
Comment on lines
+85
to
+87
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems to me that there should be a way to do this without a for loop? Not that it really matters, but something like |
||
|
||
def take_step_with_timestep( | ||
last_predictions: torch.Tensor, | ||
state: Dict[str, torch.Tensor], | ||
timestep: int, | ||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
return take_step_no_timestep(last_predictions, state) | ||
return torch.stack(log_probs_list), state | ||
|
||
if not with_timestep: | ||
return _step_function | ||
|
||
def take_short_sequence_step( | ||
last_predictions: torch.Tensor, | ||
state: Dict[str, torch.Tensor], | ||
timestep: int, | ||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
""" | ||
Take decoding step. | ||
def _step_function_with_timestep( | ||
last_predictions: torch.Tensor, | ||
state: Dict[str, torch.Tensor], | ||
timestep: int, | ||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
return _step_function(last_predictions, state) | ||
|
||
This method is the same as `take_step_no_timestep` except it uses the | ||
`short_sequence_transition_probabilities` transitions instead of `transition_probabilities` | ||
""" | ||
log_probs_list = [] | ||
for last_token in last_predictions: | ||
log_probs = torch.log(short_sequence_transition_probabilities[last_token.item()]) | ||
log_probs_list.append(log_probs) | ||
return _step_function_with_timestep | ||
|
||
return torch.stack(log_probs_list), state | ||
|
||
|
||
def take_repeated_ngrams_step( | ||
last_predictions: torch.Tensor, | ||
state: Dict[str, torch.Tensor], | ||
timestep: int, | ||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
""" | ||
Take decoding step. | ||
|
||
This method is the same as `take_step_no_timestep` except it uses the | ||
`short_sequence_transition_probabilities` transitions instead of `transition_probabilities` | ||
""" | ||
log_probs_list = [] | ||
for last_token in last_predictions: | ||
log_probs = torch.log(repeated_ngram_transition_probabilities[last_token.item()]) | ||
log_probs_list.append(log_probs) | ||
|
||
return torch.stack(log_probs_list), state | ||
take_step_no_timestep = get_step_function(transition_probabilities) | ||
take_step_with_timestep = get_step_function(transition_probabilities, with_timestep=True) | ||
take_short_sequence_step = get_step_function(short_sequence_transition_probabilities) | ||
|
||
|
||
class BeamSearchTest(AllenNlpTestCase): | ||
|
@@ -575,11 +553,6 @@ def test_gumbel_sampler(self): | |
assert all([x >= 0 and x < 4 for x in indices[0]]) | ||
assert all([x > 1 and x <= 5 for x in indices[1]]) | ||
|
||
def test_sequence_log_prob_scorer(self): | ||
# SequenceLogProbabilityScorer is the default, so manually setting the | ||
# sequence scorer shouldn't actually change anything | ||
self.beam_search.sequence_scorer = SequenceLogProbabilityScorer() | ||
|
||
def test_length_normalized_sequence_log_prob_scorer(self): | ||
""" | ||
Tests to ensure the sequences are normalized by the correct values. The end token is | ||
|
@@ -708,7 +681,7 @@ def test_repeated_ngram_blocking_constraint_update_state(self): | |
|
||
def test_take_repeated_ngram_step(self): | ||
""" | ||
Tests to ensure the top-k from the short_sequence_transition_probabilities | ||
Tests to ensure the top-k from the `repeated_ngram_transition_probabilities_0` | ||
transition matrix is expected. The transitions are: | ||
|
||
- p(1|start) = 1.0 | ||
|
@@ -761,23 +734,19 @@ def test_take_repeated_ngram_step(self): | |
0.4 * 1e-9: [1, 2, 3, 1, end] | ||
0.36 * 1e-9: [1, 3, 1, 3, end] | ||
""" | ||
step_function = get_step_function(repeated_ngram_transition_probabilities_0) | ||
self.beam_search.beam_size = 2 | ||
self.beam_search.max_steps = 5 | ||
expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]]) | ||
expected_log_probs = np.log(np.array([0.36, 0.24])) | ||
self._check_results( | ||
expected_top_k=expected_top_k, | ||
expected_log_probs=expected_log_probs, | ||
take_step=take_repeated_ngrams_step, | ||
take_step=step_function, | ||
) | ||
|
||
def test_repeated_ngram_blocking_end_to_end(self): | ||
""" | ||
This test checks to make sure the `RepeatedNGramBlockingConstraint` successfully blocks ngrams. | ||
It works by blocking ngrams of different sizes and ensures that the result of beam search | ||
is correctly changed. We rely on the beam search trace for `repeated_ngram_transition_probabilities` | ||
in `test_take_repeated_ngram_step`. | ||
""" | ||
def test_repeated_ngram_blocking_end_to_end_unigrams(self): | ||
step_function = get_step_function(repeated_ngram_transition_probabilities_0) | ||
self.beam_search.beam_size = 2 | ||
|
||
# Unigrams: On step 3, [1, 3, 1] will be blocked and [1, 3, end] will take its place | ||
|
@@ -788,9 +757,25 @@ def test_repeated_ngram_blocking_end_to_end(self): | |
self._check_results( | ||
expected_top_k=expected_top_k, | ||
expected_log_probs=expected_log_probs, | ||
take_step=take_repeated_ngrams_step, | ||
take_step=step_function, | ||
) | ||
|
||
step_function = get_step_function(repeated_ngram_transition_probabilities_1) | ||
self.beam_search.max_steps = 5 | ||
expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]]) | ||
expected_log_probs = np.log( | ||
np.array([0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1]) | ||
) | ||
self._check_results( | ||
expected_top_k=expected_top_k, | ||
expected_log_probs=expected_log_probs, | ||
take_step=step_function, | ||
) | ||
|
||
def test_repeated_ngram_blocking_end_to_end_bigrams(self): | ||
step_function = get_step_function(repeated_ngram_transition_probabilities_0) | ||
self.beam_search.beam_size = 2 | ||
|
||
# Bigrams: On step 4, [1, 3, 1, 3] will be blocked and [1, 3, 1, 2] will take its place | ||
self.beam_search.max_steps = 4 | ||
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=2)] | ||
|
@@ -799,9 +784,13 @@ def test_repeated_ngram_blocking_end_to_end(self): | |
self._check_results( | ||
expected_top_k=expected_top_k, | ||
expected_log_probs=expected_log_probs, | ||
take_step=take_repeated_ngrams_step, | ||
take_step=step_function, | ||
) | ||
|
||
def test_repeated_ngram_blocking_end_to_end_trigrams(self): | ||
step_function = get_step_function(repeated_ngram_transition_probabilities_0) | ||
self.beam_search.beam_size = 2 | ||
|
||
# Trigrams: On step 5, [1, 3, 1, 3, 1] will be blocked and [1, 2, 3, 1, 2] will take its place | ||
self.beam_search.max_steps = 5 | ||
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=3)] | ||
|
@@ -810,7 +799,7 @@ def test_repeated_ngram_blocking_end_to_end(self): | |
self._check_results( | ||
expected_top_k=expected_top_k, | ||
expected_log_probs=expected_log_probs, | ||
take_step=take_repeated_ngrams_step, | ||
take_step=step_function, | ||
) | ||
|
||
def test_repeated_ngram_blocking_end_indices(self): | ||
|
@@ -820,12 +809,13 @@ def test_repeated_ngram_blocking_end_indices(self): | |
""" | ||
# We block unigrams, but 5 (the end symbol) is repeated and it does not mess | ||
# up the sequence's probability | ||
step_function = get_step_function(repeated_ngram_transition_probabilities_0) | ||
self.beam_search.beam_size = 2 | ||
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)] | ||
expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]]) | ||
expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9])) | ||
self._check_results( | ||
expected_top_k=expected_top_k, | ||
expected_log_probs=expected_log_probs, | ||
take_step=take_repeated_ngrams_step, | ||
take_step=step_function, | ||
) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This avoids a deprecation warning.