diff --git a/CHANGELOG.md b/CHANGELOG.md
index 58577baac09..7c5588c110e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `ScaledDotProductMatrixAttention`, and converted the transformer toolkit to use it
 - Added tests to ensure that all `Attention` and `MatrixAttention` implementations are interchangeable
 - Added a way for AllenNLP Tango to read and write datasets lazily.
+- Added a way to remix datasets flexibly
 - Added `from_pretrained_transformer_and_instances` constructor to `Vocabulary`
 
 ### Fixed
@@ -43,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `ConfigurationError` is now pickleable.
 - Multitask models now support `TextFieldTensor` in heads, not just in the backbone.
 - Fixed the signature of `ScaledDotProductAttention` to match the other `Attention` classes
+- Fixed the way names are applied to Tango `Step` instances.
 
 ### Changed
 
diff --git a/allennlp/common/det_hash.py b/allennlp/common/det_hash.py
index 958e5190d51..dd68090efb4 100644
--- a/allennlp/common/det_hash.py
+++ b/allennlp/common/det_hash.py
@@ -1,6 +1,7 @@
+import collections
 import hashlib
 import io
-from typing import Any
+from typing import Any, MutableMapping
 
 import base58
 import dill
@@ -13,6 +14,9 @@ def det_hash_object(self) -> Any:
         representation. Sometimes you want to take control over what goes into that hash. In that
         case, implement this method. `det_hash()` will pickle the result of this method instead
         of the object itself.
+
+        If you return `None`, `det_hash()` falls back to the original behavior and pickles
+        the object.
         """
         raise NotImplementedError()
@@ -38,10 +42,48 @@ def det_hash_object(self) -> Any:
         return self._det_hash_object
 
 
+class DetHashWithVersion(CustomDetHash):
+    """
+    Add this class as a mixin base class to make sure your class's det_hash can be modified
+    by altering a static `VERSION` member of your class.
+    """
+
+    VERSION = None
+
+    def det_hash_object(self) -> Any:
+        if self.VERSION is not None:
+            return self.VERSION, self
+        else:
+            return None
+
+
 class _DetHashPickler(dill.Pickler):
+    def __init__(self, buffer: io.BytesIO):
+        super().__init__(buffer)
+
+        # We keep track of how deeply we are nesting the pickling of an object.
+        # If a class returns `self` as part of `det_hash_object()`, it causes an
+        # infinite recursion, because we try to pickle the `det_hash_object()`, which
+        # contains `self`, which returns a `det_hash_object()`, etc.
+        # So we keep track of how many times recursively we are trying to pickle the
+        # same object. We only call `det_hash_object()` the first time. We assume that
+        # if `det_hash_object()` returns `self` in any way, we want the second time
+        # to just pickle the object as normal. `DetHashWithVersion` takes advantage
+        # of this ability.
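+        #
+        # A concrete walk-through (hypothetical): for an object whose
+        # `det_hash_object()` returns `(VERSION, self)`, `persistent_id()` below
+        # emits a persistent ID containing that tuple. Pickling the ID re-enters
+        # `save()` for the same object at depth 2, where the count check makes
+        # `persistent_id()` return `None`, so the object is pickled normally and
+        # the recursion stops.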
+        self.recursively_pickled_ids: MutableMapping[int, int] = collections.Counter()
+
+    def save(self, obj, save_persistent_id=True):
+        self.recursively_pickled_ids[id(obj)] += 1
+        super().save(obj, save_persistent_id)
+        self.recursively_pickled_ids[id(obj)] -= 1
+
     def persistent_id(self, obj: Any) -> Any:
-        if isinstance(obj, CustomDetHash):
-            return obj.__class__.__qualname__, obj.det_hash_object()
+        if isinstance(obj, CustomDetHash) and self.recursively_pickled_ids[id(obj)] <= 1:
+            det_hash_object = obj.det_hash_object()
+            if det_hash_object is not None:
+                return obj.__class__.__module__, obj.__class__.__qualname__, det_hash_object
+            else:
+                return None
         elif isinstance(obj, type):
             return obj.__module__, obj.__qualname__
         else:
diff --git a/allennlp/common/sequences.py b/allennlp/common/sequences.py
new file mode 100644
index 00000000000..855638b4ae1
--- /dev/null
+++ b/allennlp/common/sequences.py
@@ -0,0 +1,85 @@
+import bisect
+import random
+from collections import abc
+from typing import Sequence, Optional, Union
+
+
+class ShuffledSequence(abc.Sequence):
+    """
+    Produces a shuffled view of a sequence, such as a list.
+
+    This assumes that the inner sequence never changes. If it does, the results
+    are undefined.
+    """
+
+    def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):
+        self.inner = inner_sequence
+        self.indices: Sequence[int]
+        if indices is None:
+            self.indices = list(range(len(inner_sequence)))
+            random.shuffle(self.indices)
+        else:
+            self.indices = indices
+
+    def __len__(self) -> int:
+        return len(self.indices)
+
+    def __getitem__(self, i: Union[int, slice]):
+        if isinstance(i, int):
+            return self.inner[self.indices[i]]
+        else:
+            return ShuffledSequence(self.inner, self.indices[i])
+
+    def __contains__(self, item) -> bool:
+        for i in self.indices:
+            if self.inner[i] == item:
+                return True
+        return False
+
+
+class SlicedSequence(ShuffledSequence):
+    """
+    Produces a sequence that's a slice into another sequence, without copying the elements.
+
+    This assumes that the inner sequence never changes. If it does, the results
+    are undefined.
+    """
+
+    def __init__(self, inner_sequence: Sequence, s: slice):
+        super().__init__(inner_sequence, range(*s.indices(len(inner_sequence))))
+
+
+class ConcatenatedSequence(abc.Sequence):
+    """
+    Produces a sequence that's the concatenation of multiple other sequences, without
+    copying the elements.
+
+    This assumes that the inner sequences never change. If they do, the results
+    are undefined.
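+
+    For example (illustrative):
+
+        cat = ConcatenatedSequence([1, 2], [], [3, 4, 5])
+        assert len(cat) == 5
+        assert cat[2] == 3
+        assert list(cat[1:4]) == [2, 3, 4]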
+    """
+
+    def __init__(self, *sequences: Sequence):
+        self.sequences = sequences
+        self.cumulative_sequence_lengths = [0]
+        for sequence in sequences:
+            self.cumulative_sequence_lengths.append(
+                self.cumulative_sequence_lengths[-1] + len(sequence)
+            )
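+
+    # `cumulative_sequence_lengths[n]` holds the index of the first element of
+    # `sequences[n]` within the concatenation; `bisect_right()` in `__getitem__`
+    # uses it to find the inner sequence that a global index falls into.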
+    def __len__(self):
+        return self.cumulative_sequence_lengths[-1]
+
+    def __getitem__(self, i: Union[int, slice]):
+        if isinstance(i, int):
+            if i < 0:
+                i += len(self)
+            if i < 0 or i >= len(self):
+                raise IndexError("list index out of range")
+            sequence_index = bisect.bisect_right(self.cumulative_sequence_lengths, i) - 1
+            i -= self.cumulative_sequence_lengths[sequence_index]
+            return self.sequences[sequence_index][i]
+        else:
+            return SlicedSequence(self, i)
+
+    def __contains__(self, item) -> bool:
+        return any(s.__contains__(item) for s in self.sequences)
diff --git a/allennlp/common/sqlite_sparse_sequence.py b/allennlp/common/sqlite_sparse_sequence.py
index 634b5393a3a..7a45d68258d 100644
--- a/allennlp/common/sqlite_sparse_sequence.py
+++ b/allennlp/common/sqlite_sparse_sequence.py
@@ -2,10 +2,9 @@
 import shutil
 from os import PathLike
 from typing import MutableSequence, Any, Union, Iterable
-
 from sqlitedict import SqliteDict
 
-from allennlp.tango.dataloader import ShuffledSequence
+from allennlp.common.sequences import SlicedSequence
 
 
 class SqliteSparseSequence(MutableSequence[Any]):
@@ -28,7 +27,7 @@ def __getitem__(self, i: Union[int, slice]) -> Any:
             else:
                 return None
         elif isinstance(i, slice):
-            return ShuffledSequence(self, range(*i.indices(len(self))))
+            return SlicedSequence(self, i)
         else:
             raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")
diff --git a/allennlp/tango/dataloader.py b/allennlp/tango/dataloader.py
index 59faffa5acd..bcf3b99db75 100644
--- a/allennlp/tango/dataloader.py
+++ b/allennlp/tango/dataloader.py
@@ -4,10 +4,8 @@
 """
 
 import logging
-import random
-from collections import abc
 from math import floor, ceil
-from typing import Optional, Iterator, Sequence, Union, Dict, Any
+from typing import Optional, Iterator, Sequence, Dict, Any
 
 import more_itertools
 import torch
@@ -22,6 +20,7 @@
     Vocabulary,
 )
 from allennlp.nn.util import move_to_device
+from allennlp.common.sequences import ShuffledSequence
 
 
 class TangoDataLoader(Registrable):
@@ -86,36 +85,6 @@ def set_target_device(self, device: torch.device) -> None:
         self.target_device = device
 
 
-class ShuffledSequence(abc.Sequence):
-    """
-    Produces a shuffled view of a sequence, such as a list.
-
-    This assumes that the inner sequence never changes. If it does, the results
-    are undefined.
-    """
-
-    def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):
-        self.inner = inner_sequence
-        self.indices: Sequence[int]
-        if indices is None:
-            self.indices = list(range(len(inner_sequence)))
-            random.shuffle(self.indices)
-        else:
-            self.indices = indices
-
-    def __len__(self) -> int:
-        return len(self.inner)
-
-    def __getitem__(self, i: Union[int, slice]):
-        if isinstance(i, int):
-            return self.inner[self.indices[i]]
-        else:
-            return ShuffledSequence(self.inner, self.indices[i])
-
-    def __contains__(self, item) -> bool:
-        return self.inner.__contains__(item)
-
-
 @TangoDataLoader.register("batch_size")
 class BatchSizeDataLoader(TangoDataLoader):
     """A data loader that turns instances into batches with a constant number of instances
diff --git a/allennlp/tango/dataset.py b/allennlp/tango/dataset.py
index 6abf6cdf979..0d5478dfe26 100644
--- a/allennlp/tango/dataset.py
+++ b/allennlp/tango/dataset.py
@@ -4,11 +4,14 @@
 """
 
 import itertools
+import random
+import re
 from dataclasses import dataclass, field
 from typing import Mapping, Any, Optional, Sequence, Dict
 
 from allennlp.data import Vocabulary, DatasetReader, Instance
 from allennlp.tango.step import Step
+from allennlp.common.sequences import SlicedSequence, ConcatenatedSequence, ShuffledSequence
 from tqdm import tqdm
 
@@ -39,9 +42,9 @@ def __len__(self) -> int:
 @Step.register("dataset_reader_adapter")
 class DatasetReaderAdapterStep(Step):
     """
-    This step creates an `AllenNlpDataset` from old-school dataset readers. If you're
+    This step creates a `DatasetDict` from old-school dataset readers. If you're
     tempted to write a new `DatasetReader`, and then use this step with it, don't.
-    Just write a `Step` that creates the `AllenNlpDataset` you need directly.
+    Just write a `Step` that creates the `DatasetDict` you need directly.
     """
 
     DETERMINISTIC = True  # We're giving the dataset readers some credit here.
@@ -72,3 +75,68 @@ def run(self, reader: DatasetReader, splits: Dict[str, str]) -> DatasetDict:
             # instance.index_fields(vocab)
 
         return DatasetDict(splits=instances_map, vocab=vocab)
+
+
+@Step.register("dataset_remix")
+class DatasetRemixStep(Step):
+    """
+    This step can remix splits in a dataset into new splits. Each new split is
+    specified as a combination of old splits, e.g. `"train[:100] + dev"`.
+    """
+
+    DETERMINISTIC = True
+    CACHEABLE = False  # This is so fast it's not worth caching.
+    VERSION = "001"
+
+    def run(  # type: ignore
+        self,
+        input: DatasetDict,
+        new_splits: Dict[str, str],
+        keep_old_splits: bool = True,
+        shuffle_before: bool = False,
+        shuffle_after: bool = False,
+        random_seed: int = 1532637578,
+    ) -> DatasetDict:
+        random.seed(random_seed)
+
+        if shuffle_before:
+            input_splits: Mapping[str, Sequence[Any]] = {
+                split_name: ShuffledSequence(split_instances)
+                for split_name, split_instances in input.splits.items()
+            }
+        else:
+            input_splits = input.splits
+
+        def get_slice(split_name: str) -> Sequence[Any]:
+            slice_match = re.match(r"(.*)\[([0123456789:]*)]", split_name)
+            if slice_match is None:
+                return input[split_name]
+            else:
+                split_name = slice_match[1]
+                slice_args = [int(a) if len(a) > 0 else None for a in slice_match[2].split(":")]
+                return SlicedSequence(input[split_name], slice(*slice_args))
+
+        def parse_split_spec(split_spec: str):
+            parts = [get_slice(name.strip()) for name in split_spec.split("+")]
+            if len(parts) == 1:
+                return parts[0]
+            else:
+                return ConcatenatedSequence(*parts)
+
+        if keep_old_splits:
+            result = dict(input_splits.items())
+        else:
+            result = {}
+        result.update(
+            {
+                new_split_name: parse_split_spec(new_split_spec)
+                for new_split_name, new_split_spec in new_splits.items()
+            }
+        )
+
+        if shuffle_after:
+            result = {
+                split_name: ShuffledSequence(split_instances)
+                for split_name, split_instances in result.items()
+            }
+
+        return DatasetDict(vocab=input.vocab, metadata=input.metadata, splits=result)
diff --git a/allennlp/tango/hf_dataset.py b/allennlp/tango/hf_dataset.py
index 67bed8715d4..520d1239ea9 100644
--- a/allennlp/tango/hf_dataset.py
+++ b/allennlp/tango/hf_dataset.py
@@ -10,7 +10,7 @@
 
 @Step.register("hf_dataset")
 class HuggingfaceDataset(Step):
-    """This steps reads a huggingface dataset and returns it in `AllenNlpDataset` format."""
+    """This step reads a huggingface dataset and returns it in `DatasetDict` format."""
 
     DETERMINISTIC = True
     VERSION = "001"
diff --git a/allennlp/tango/step.py b/allennlp/tango/step.py
index 225bf5310ff..6f267309de3 100644
--- a/allennlp/tango/step.py
+++ b/allennlp/tango/step.py
@@ -301,13 +301,8 @@ def __init__(
         self.unique_id_cache: Optional[str] = None
         if step_name is None:
             self.name = self.unique_id()
-            self.only_if_needed = False
         else:
             self.name = step_name
-            self.only_if_needed = True
-
-        if only_if_needed is not None:
-            self.only_if_needed = only_if_needed
 
         if cache_results is True:
             if not self.CACHEABLE:
@@ -342,9 +337,16 @@ def __init__(
             assert False, "Step.DETERMINISTIC or step.CACHEABLE are set to an invalid value."
         else:
             raise ConfigurationError(
-                f"Step {step_name}'s cache_results parameter is set to an invalid value."
+                f"Step {self.name}'s cache_results parameter is set to an invalid value."
             )
 
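+        # By default, an anonymous step runs only when another step needs its
+        # output, while a named step whose results are cached runs unconditionally.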
+        if step_name is None:
+            self.only_if_needed = True
+        else:
+            self.only_if_needed = not self.cache_results
+        if only_if_needed is not None:
+            self.only_if_needed = only_if_needed
+
         self.work_dir_for_run: Optional[
             PathLike
         ] = None  # This is set only while the run() method runs.
@@ -356,6 +358,7 @@ def from_params(
         constructor_to_call: Callable[..., "Step"] = None,
         constructor_to_inspect: Union[Callable[..., "Step"], Callable[["Step"], None]] = None,
         existing_steps: Optional[Dict[str, "Step"]] = None,
+        step_name: Optional[str] = None,
         **extras,
     ) -> "Step":
         # Why do we need a custom from_params? Step classes have a run() method that takes all the
@@ -445,7 +448,7 @@ def from_params(
         else:
             params.assert_empty(subclass.__name__)
 
-        return subclass(**kwargs)
+        return subclass(step_name=step_name, **kwargs)
 
     @abstractmethod
     def run(self, **kwargs) -> T:
@@ -456,7 +459,7 @@ def _run_with_work_dir(self, cache: StepCache, **kwargs) -> T:
         if self.work_dir_for_run is not None:
             raise ValueError("You can only run a Step's run() method once at a time.")
 
-        logger.info("Starting run for step %s of type %s", self.name, self.__class__)
+        logger.info("Starting run for step %s of type %s", self.name, self.__class__.__name__)
 
         if self.DETERMINISTIC:
             random.seed(784507111)
@@ -561,7 +564,7 @@ def det_hash_object(self) -> Any:
     def unique_id(self) -> str:
         """Returns the unique ID for this step.
 
-        Unique IDs are of the shape `$class_name-$version-$hash`, where the hash is the has of the
+        Unique IDs are of the shape `$class_name-$version-$hash`, where the hash is the hash of the
         inputs for deterministic steps, and a random string of characters for non-deterministic
         ones."""
         if self.unique_id_cache is None:
             self.unique_id_cache = self.__class__.__name__
@@ -679,7 +682,7 @@ def step_graph_from_params(params: Dict[str, Params]) -> Dict[str, Step]:
         step_params_backup = copy.deepcopy(step_params)
         try:
             parsed_steps[step_name] = Step.from_params(
-                step_params, existing_steps=parsed_steps, extras={"step_name": step_name}
+                step_params, existing_steps=parsed_steps, step_name=step_name
             )
             steps_parsed += 1
         except _RefStep.MissingStepError:
diff --git a/tests/common/det_hash_test.py b/tests/common/det_hash_test.py
new file mode 100644
index 00000000000..e22a7092fcb
--- /dev/null
+++ b/tests/common/det_hash_test.py
@@ -0,0 +1,57 @@
+from allennlp.common.det_hash import det_hash, DetHashWithVersion
+
+
+def test_normal_det_hash():
+    class C:
+        VERSION = 1
+
+        def __init__(self, x: int):
+            self.x = x
+
+    c1_1 = C(10)
+    c2_1 = C(10)
+    c3_1 = C(20)
+    assert det_hash(c1_1) == det_hash(c2_1)
+    assert det_hash(c3_1) != det_hash(c2_1)
+
+    class C:
+        VERSION = 2
+
+        def __init__(self, x: int):
+            self.x = x
+
+    c1_2 = C(10)
+    c2_2 = C(10)
+    c3_2 = C(20)
+    assert det_hash(c1_2) == det_hash(c2_2)
+    assert det_hash(c3_2) != det_hash(c2_2)
+    assert det_hash(c1_2) == det_hash(c1_1)  # because the version isn't taken into account
+    assert det_hash(c3_2) == det_hash(c3_1)  # because the version isn't taken into account
+
+
+def test_versioned_det_hash():
+    class C(DetHashWithVersion):
+        VERSION = 1
+
+        def __init__(self, x: int):
+            self.x = x
+
+    c1_1 = C(10)
+    c2_1 = C(10)
+    c3_1 = C(20)
+    assert det_hash(c1_1) == det_hash(c2_1)
+    assert det_hash(c3_1) != det_hash(c2_1)
+
+    class C(DetHashWithVersion):
+        VERSION = 2
+
+        def __init__(self, x: int):
+            self.x = x
+
+    c1_2 = C(10)
+    c2_2 = C(10)
+    c3_2 = C(20)
+    assert det_hash(c1_2) == det_hash(c2_2)
+    assert det_hash(c3_2) != det_hash(c2_2)
+    assert det_hash(c1_2) != det_hash(c1_1)  # because the version is taken into account
+    assert det_hash(c3_2) != det_hash(c3_1)  # because the version is taken into account
diff --git a/tests/common/sequences_test.py b/tests/common/sequences_test.py
new file mode 100644
index 00000000000..bcab6389bdc
--- /dev/null
+++ b/tests/common/sequences_test.py
@@ -0,0 +1,57 @@
+import pytest
+from allennlp.common.sequences import ConcatenatedSequence
+
+
+def assert_equal_including_exceptions(expected_fn, actual_fn):
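+    """
+    Check that `actual_fn` behaves like `expected_fn`: both return equal values,
+    or `actual_fn()` raises the same exception class that `expected_fn()` raises
+    (e.g. an `IndexError` for an out-of-range index).
+    """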
+    try:
+        expected = expected_fn()
+    except Exception as e:
+        with pytest.raises(e.__class__):
+            actual_fn()
+    else:
+        assert expected == actual_fn()
+
+
+def test_concatenated_sequence():
+    l1 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+    l2 = ConcatenatedSequence([0, 1], [], [2, 3, 4], [5, 6, 7, 8, 9], [])
+
+    # __len__()
+    assert len(l1) == len(l2)
+
+    # index()
+    for item in l1 + [999]:
+        # no indices
+        assert_equal_including_exceptions(lambda: l1.index(item), lambda: l2.index(item))
+
+        # only start index
+        for index in range(-15, 15):
+            assert_equal_including_exceptions(
+                lambda: l1.index(item, index), lambda: l2.index(item, index)
+            )
+
+        # start and stop index
+        for start_index in range(-15, 15):
+            for end_index in range(-15, 15):
+                assert_equal_including_exceptions(
+                    lambda: l1.index(item, start_index, end_index),
+                    lambda: l2.index(item, start_index, end_index),
+                )
+
+    # __getitem__()
+    for index in range(-15, 15):
+        assert_equal_including_exceptions(lambda: l1[index], lambda: l2[index])
+
+    for start_index in range(-15, 15):
+        for end_index in range(-15, 15):
+            assert_equal_including_exceptions(
+                lambda: l1[start_index:end_index], lambda: list(l2[start_index:end_index])
+            )
+
+    # count()
+    for item in l1 + [999]:
+        assert_equal_including_exceptions(lambda: l1.count(item), lambda: l2.count(item))
+
+    # __contains__()
+    for item in l1 + [999]:
+        assert_equal_including_exceptions(lambda: item in l1, lambda: item in l2)
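
A quick illustrative sketch (not part of the diff; the data is made up) of how the new
sequence views in `allennlp/common/sequences.py` compose:

    from allennlp.common.sequences import (
        ConcatenatedSequence,
        ShuffledSequence,
        SlicedSequence,
    )

    data = list(range(10))
    assert list(SlicedSequence(data, slice(2, 5))) == [2, 3, 4]  # a view, not a copy
    assert len(ConcatenatedSequence(data, data)) == 20  # lazy concatenation
    assert sorted(ShuffledSequence(data)) == data  # same elements, shuffled order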