This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Commit 78dc1df

Data cleanup and additions from the wikitables branch (#1084)

* Data cleanup and additions from the wikitables branch
* Added docs

1 parent 6829edb

Showing 10 changed files with 147 additions and 25 deletions.
@@ -0,0 +1,77 @@
import logging
from typing import List, Tuple, Dict, Iterable, Generator, Union
from collections import defaultdict

from overrides import overrides
import numpy

from allennlp.data.fields import MetadataField
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.iterators.bucket_iterator import BucketIterator

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@DataIterator.register("epoch_tracking_bucket")
class EpochTrackingBucketIterator(BucketIterator):
    """
    This is essentially a :class:`allennlp.data.iterators.BucketIterator` with just one difference.
    It keeps track of the epoch number, and adds that as an additional meta field to each instance.
    That way, ``Model.forward`` will have access to this information. We do this by keeping track of
    epochs globally, and incrementing them whenever the iterator is called. However, the iterator is
    called both for training and validation sets. So, we keep a dict of epoch numbers, one key per
    dataset.
    Parameters
    ----------
    See :class:`BucketIterator`.
    """
    def __init__(self,
                 sorting_keys: List[Tuple[str, str]],
                 padding_noise: float = 0.1,
                 biggest_batch_first: bool = False,
                 batch_size: int = 32,
                 instances_per_epoch: int = None,
                 max_instances_in_memory: int = None) -> None:
        super(EpochTrackingBucketIterator, self).__init__(sorting_keys=sorting_keys,
                                                          padding_noise=padding_noise,
                                                          biggest_batch_first=biggest_batch_first,
                                                          batch_size=batch_size,
                                                          instances_per_epoch=instances_per_epoch,
                                                          max_instances_in_memory=max_instances_in_memory)
        # Epoch number value per dataset.
        self._global_epoch_nums: Dict[int, int] = defaultdict(int)

    @overrides
    def __call__(self,
                 instances: Iterable[Instance],
                 num_epochs: int = None,
                 shuffle: bool = True,
                 cuda_device: int = -1,
                 for_training: bool = True) -> Generator[Dict[str, Union[numpy.ndarray,
                                                                         Dict[str, numpy.ndarray]]],
                                                         None, None]:
        """
        See ``DataIterator.__call__`` for parameters.
        """
        dataset_id = id(instances)
        if num_epochs is None:
            while True:
                self._add_epoch_num_to_instances(instances, dataset_id)
                yield from self._yield_one_epoch(instances, shuffle, cuda_device, for_training)
                self._global_epoch_nums[dataset_id] += 1
        else:
            for _ in range(num_epochs):
                self._add_epoch_num_to_instances(instances, dataset_id)
                yield from self._yield_one_epoch(instances, shuffle, cuda_device, for_training)
                self._global_epoch_nums[dataset_id] += 1

    def _add_epoch_num_to_instances(self,
                                    instances: Iterable[Instance],
                                    dataset_id: int) -> None:
        for instance in instances:
            # TODO(pradeep): Mypy complains here most probably because ``fields`` is typed as a
            # ``Mapping``, and not a ``Dict``. Ignoring this for now, but the type of fields
            # probably needs to be changed.
            instance.fields["epoch_num"] = MetadataField(self._global_epoch_nums[dataset_id])  #type: ignore
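
For context, here is a minimal sketch (not part of this commit) of how a model could consume the extra field. The model class, its arguments, and the annealing logic are hypothetical; what the iterator actually guarantees is that each batch carries an ``epoch_num`` entry, which ``Model.forward`` receives as a list with one integer per instance (as the test below verifies).

from typing import Dict, List

import torch


class EpochAwareModel(torch.nn.Module):
    """Hypothetical model, for illustration only; not part of this commit."""
    def forward(self,
                tokens: Dict[str, torch.Tensor],
                epoch_num: List[int] = None,
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # AllenNLP passes each key of the batch dict to ``forward`` as a keyword
        # argument, so the field added by the iterator arrives here as a list with
        # one integer per instance, all equal to the current epoch for this dataset.
        current_epoch = epoch_num[0] if epoch_num else 0
        # Example use: anneal a weight over the first ten epochs.
        annealing_weight = min(1.0, current_epoch / 10.0)
        return {"annealing_weight": torch.tensor(annealing_weight)}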
tests/data/iterators/epoch_tracking_bucket_iterator_test.py (27 additions, 0 deletions)
@@ -0,0 +1,27 @@
from allennlp.data.iterators import EpochTrackingBucketIterator
from tests.data.iterators.basic_iterator_test import IteratorTest


class EpochTrackingBucketIteratorTest(IteratorTest):
    def setUp(self):
        # The super class creates a self.instances field and populates it with some instances with
        # TextFields.
        super(EpochTrackingBucketIteratorTest, self).setUp()
        self.iterator = EpochTrackingBucketIterator(sorting_keys=[["text", "num_tokens"]])
        # We'll add more to create a second dataset.
        self.more_instances = [
                self.create_instance(["this", "is", "a", "sentence"]),
                self.create_instance(["this", "is", "in", "the", "second", "dataset"]),
                self.create_instance(["so", "is", "this", "one"])
                ]

    def test_iterator_tracks_epochs_per_dataset(self):
        generated_dataset1 = list(self.iterator(self.instances, num_epochs=2))
        generated_dataset2 = list(self.iterator(self.more_instances, num_epochs=2))

        # First dataset has five sentences. See ``IteratorTest.setUp``.
        assert generated_dataset1[0]["epoch_num"] == [0, 0, 0, 0, 0]
        assert generated_dataset1[1]["epoch_num"] == [1, 1, 1, 1, 1]
        # Second dataset has three sentences.
        assert generated_dataset2[0]["epoch_num"] == [0, 0, 0]
        assert generated_dataset2[1]["epoch_num"] == [1, 1, 1]
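
Because the class is registered under the name ``epoch_tracking_bucket``, it can also be selected through the usual ``from_params`` path. Below is a minimal sketch, not part of this commit: the ``sorting_keys`` and default ``batch_size`` mirror the test above, and ``instances`` is a stand-in for any indexed dataset whose instances have a ``"text"`` TextField.

from allennlp.common import Params
from allennlp.data.iterators import DataIterator

# Build the iterator from a params blob, the same way a training configuration's
# "iterator" section would select it via its registered name.
iterator = DataIterator.from_params(Params({
        "type": "epoch_tracking_bucket",
        "sorting_keys": [["text", "num_tokens"]],
        "batch_size": 32
        }))

# ``instances`` stands in for an indexed dataset, e.g. the fixtures built by
# ``IteratorTest.setUp`` above.
for batch in iterator(instances, num_epochs=2):
    print(batch["epoch_num"])  # [0, 0, ...] on the first epoch, [1, 1, ...] on the second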