This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Max out span extractor #5520

Merged
merged 9 commits into from
Jan 5, 2022
Changes according to coding conventions.
MSLars committed Dec 20, 2021
commit 950504318267d91093c200aef7549723e1ad1c59
1 change: 1 addition & 0 deletions allennlp/modules/span_extractors/__init__.py
@@ -6,3 +6,4 @@
from allennlp.modules.span_extractors.bidirectional_endpoint_span_extractor import (
BidirectionalEndpointSpanExtractor,
)
from allennlp.modules.span_extractors.max_pooling_span_extractor import MaxPoolingSpanExtractor
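For context, a minimal usage sketch of the newly exported extractor (illustrative only; shapes follow the span-extractor convention of inclusive `(start, end)` indices, and the output dimension equals `input_dim`):

```python
import torch
from allennlp.modules.span_extractors import MaxPoolingSpanExtractor

# Batch of 2 sequences, 6 tokens each, 30-dimensional token embeddings.
sequence_tensor = torch.randn(2, 6, 30)
# Two spans per batch element; indices are inclusive (start, end).
span_indices = torch.LongTensor([[[0, 2], [3, 5]], [[1, 1], [2, 4]]])

extractor = MaxPoolingSpanExtractor(input_dim=30)
span_embeddings = extractor(sequence_tensor, span_indices)
print(span_embeddings.shape)  # torch.Size([2, 2, 30])
```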
33 changes: 17 additions & 16 deletions allennlp/modules/span_extractors/max_pooling_span_extractor.py
@@ -4,7 +4,6 @@
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.nn import util
from allennlp.nn.util import masked_max

@@ -43,11 +42,11 @@ class MaxPoolingSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""

def __init__(
self,
input_dim: int,
num_width_embeddings: int = None,
span_width_embedding_dim: int = None,
bucket_widths: bool = False,
self,
input_dim: int,
num_width_embeddings: int = None,
span_width_embedding_dim: int = None,
bucket_widths: bool = False,
) -> None:
super().__init__(
input_dim=input_dim,
@@ -62,11 +61,11 @@ def get_output_dim(self) -> int:
return self._input_dim

def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
sequence_mask: torch.BoolTensor = None,
span_indices_mask: torch.BoolTensor = None,
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
sequence_mask: torch.BoolTensor = None,
span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:

if sequence_tensor.size(-1) != self._input_dim:
@@ -84,7 +83,7 @@ def _embed_spans(

if (span_indices[:, :, 0] > span_indices[:, :, 1]).any():
raise IndexError(
f"Span start above span end",
"Span start above span end",
)

# Calculate the maximum sequence length for each element in batch.
@@ -101,15 +100,17 @@ def _embed_spans(
adopted_span_indices = torch.tensor(span_indices, device=span_indices.device)

for b in range(sequence_lengths.shape[0]):
adopted_span_indices[b, :, 1][adopted_span_indices[b, :, 1] >= sequence_lengths[b]] = sequence_lengths[
b] - 1
adopted_span_indices[b, :, 1][adopted_span_indices[b, :, 1] >= sequence_lengths[b]] = (
sequence_lengths[b] - 1
)

# Raise Error if span indices were completely masked by sequence mask.
# We only adjust span_end to the last valid index, so if span_end is below span_start, both were above the max index:
# We only adjust span_end to the last valid index, so if span_end is below span_start,
# both were above the max index:

if (adopted_span_indices[:, :, 0] > adopted_span_indices[:, :, 1]).any():
raise IndexError(
f"Span indices were masked out entirely by sequence mask",
"Span indices were masked out entirely by sequence mask",
)

# span_vals <- (batch x num_spans x max_span_length x dim)
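For readers skimming the diff, the body of `_embed_spans` (partly elided above) clamps each span end to the last unmasked position and then max-pools over the positions inside the span. The following is a standalone sketch of that idea under assumed shapes, not the module's exact code; the helper name `max_pool_spans` is hypothetical, while `get_lengths_from_binary_sequence_mask`, `batched_index_select`, and `masked_max` are existing AllenNLP utilities:

```python
import torch
from allennlp.nn import util
from allennlp.nn.util import masked_max


def max_pool_spans(sequence_tensor, span_indices, sequence_mask=None):
    # sequence_tensor: (batch, seq_len, dim); span_indices: (batch, num_spans, 2), inclusive ends.
    if sequence_mask is not None:
        sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
    else:
        sequence_lengths = torch.full(
            (sequence_tensor.size(0),),
            sequence_tensor.size(1),
            dtype=torch.long,
            device=sequence_tensor.device,
        )

    # Clamp span ends so they never point past the last unmasked token.
    span_starts = span_indices[:, :, 0]
    span_ends = torch.min(span_indices[:, :, 1], (sequence_lengths - 1).unsqueeze(1))

    # Enumerate every position inside the widest span and mask out-of-span positions.
    max_span_width = int((span_ends - span_starts).max().item()) + 1
    offsets = torch.arange(max_span_width, device=sequence_tensor.device)
    positions = span_starts.unsqueeze(-1) + offsets  # (batch, num_spans, max_span_width)
    position_mask = positions <= span_ends.unsqueeze(-1)
    positions = positions.clamp(max=sequence_tensor.size(1) - 1)

    # Gather span token embeddings: (batch, num_spans, max_span_width, dim).
    span_vals = util.batched_index_select(sequence_tensor, positions)

    # Max over the span-width dimension, ignoring out-of-span positions.
    # (The real module additionally raises an IndexError for spans that the
    # sequence mask removes entirely.)
    return masked_max(span_vals, position_mask.unsqueeze(-1), dim=2)
```

The PR itself adjusts `adopted_span_indices` with a per-batch loop instead, as shown in the hunk above; the sketch only illustrates the intended semantics.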
71 changes: 37 additions & 34 deletions tests/modules/span_extractors/max_pooling_span_extractor_test.py
@@ -1,4 +1,3 @@
import numpy as np
import pytest
import torch

@@ -34,7 +33,7 @@ def test_max_values_extracted(self):
assert extractor.get_output_dim() == 30
assert extractor.get_input_dim() == 30

# We iterate ofer the tensor to compare the span extractors's results
# We iterate over the tensor to compare the span extractors's results
# with the results of python max operation over each dimension for each span and for each batch
# For each batch
for batch, X in enumerate(indices):
@@ -43,21 +42,23 @@ def test_max_values_extracted(self):

# original features of current tested span
# span_width x embedding dim (30)
span_features_complete = sequence_tensor[batch][span_def[0]: span_def[1] + 1]
span_features_complete = sequence_tensor[batch][span_def[0] : span_def[1] + 1]

# comparisson for each dimension
# comparison for each dimension
for i in range(extractor.get_output_dim()):
# get the features for dimension i of current span
features_from_span = span_features_complete[:, i]
real_max_value = max(features_from_span)

extrected_max_value = span_representations[batch, indices_ind, i]

assert real_max_value == extrected_max_value, f"Error extracting max value for " \
f"batch {batch}, span {indices_ind} on dimension {i}." \
f"expected {real_max_value} " \
f"but got {extrected_max_value} which is " \
f"not the maximum element."
assert real_max_value == extrected_max_value, (
f"Error extracting max value for "
f"batch {batch}, span {indices_ind} on dimension {i}."
f"expected {real_max_value} "
f"but got {extrected_max_value} which is "
f"not the maximum element."
)

def test_sequence_mask_correct_excluded(self):
# Check if span indices masked out by the sequence mask are ignored when computing
@@ -70,31 +71,30 @@ def test_sequence_mask_correct_excluded(self):
# define sequence mask
seq_mask = torch.BoolTensor([[True] * 4 + [False] * 2, [True] * 5 + [False] * 1])

span_representations = extractor(sequence_tensor,
indices,
sequence_mask=seq_mask
)
span_representations = extractor(sequence_tensor, indices, sequence_mask=seq_mask)

# After we computed the representations we set values to -inf
# to compute the "real" max-pooling with python's max function.
sequence_tensor[seq_mask == False] = float("-inf")
sequence_tensor[torch.logical_not(seq_mask)] = float("-inf")

# Comparisson is similar to test_max_values_extracted
# Comparison is similar to test_max_values_extracted
for batch, X in enumerate(indices):
for indices_ind, span_def in enumerate(X):

span_features_complete = sequence_tensor[batch][span_def[0]: span_def[1] + 1]
span_features_complete = sequence_tensor[batch][span_def[0] : span_def[1] + 1]

for i, _ in enumerate(span_features_complete):
features_from_span = span_features_complete[:, i]
real_max_value = max(features_from_span)
extrected_max_value = span_representations[batch, indices_ind, i]

assert real_max_value == extrected_max_value, f"Error extracting max value for " \
f"batch {batch}, span {indices_ind} on dimension {i}." \
f"expected {real_max_value} " \
f"but got {extrected_max_value} which is " \
f"not the maximum element."
assert real_max_value == extrected_max_value, (
f"Error extracting max value for "
f"batch {batch}, span {indices_ind} on dimension {i}."
f"expected {real_max_value} "
f"but got {extrected_max_value} which is "
f"not the maximum element."
)

def test_span_mask_correct_excluded(self):
# All masked out span indices by span_mask should be '0'
@@ -106,28 +106,31 @@

span_mask = torch.BoolTensor([[True] * 3, [False] * 3])

span_representations = extractor(sequence_tensor,
indices,
span_indices_mask=span_mask,
)
span_representations = extractor(
sequence_tensor,
indices,
span_indices_mask=span_mask,
)

# The span mask masked out all indices in the last batch
# We check whether all soan representations for this batch are '0'
# The span-mask masks out all indices in the last batch
# We check whether all span representations for this batch are '0'
X = indices[-1]
batch = -1
for indices_ind, span_def in enumerate(X):

span_features_complete = sequence_tensor[batch][span_def[0]: span_def[1] + 1]
span_features_complete = sequence_tensor[batch][span_def[0] : span_def[1] + 1]

for i, _ in enumerate(span_features_complete):
real_max_value = torch.FloatTensor([0.])
real_max_value = torch.FloatTensor([0.0])
extrected_max_value = span_representations[batch, indices_ind, i]

assert real_max_value == extrected_max_value, f"Error extracting max value for " \
f"batch {batch}, span {indices_ind} on dimension {i}." \
f"expected {real_max_value} " \
f"but got {extrected_max_value} which is " \
f"not the maximum element."
assert real_max_value == extrected_max_value, (
f"Error extracting max value for "
f"batch {batch}, span {indices_ind} on dimension {i}."
f"expected {real_max_value} "
f"but got {extrected_max_value} which is "
f"not the maximum element."
)

def test_inconsistent_extractor_dimension_throws_exception(self):

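As a condensed, hypothetical companion to the tests above, the same dimension-wise comparison can be written against a plain `torch.max` over a single span (tensor values and shapes here are illustrative, not taken from the test file):

```python
import torch
from allennlp.modules.span_extractors import MaxPoolingSpanExtractor

sequence_tensor = torch.randn(1, 5, 8)
span_indices = torch.LongTensor([[[1, 3]]])  # one span covering tokens 1..3 (inclusive)

extractor = MaxPoolingSpanExtractor(input_dim=8)
extracted = extractor(sequence_tensor, span_indices)[0, 0]

# Reference result: dimension-wise max over the span's token embeddings.
expected = sequence_tensor[0, 1:4].max(dim=0).values

assert torch.allclose(extracted, expected)
```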