This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Dataset remix #5372

Merged
merged 39 commits into from
Aug 25, 2021
10c5479
Adds a dataset that can be read and written lazily
dirkgr Aug 7, 2021
36f9b67
This approach might work better.
dirkgr Aug 7, 2021
e74540d
Make ShuffledSequence take indices
dirkgr Aug 7, 2021
0eb53bf
Formatting
dirkgr Aug 7, 2021
dcedfd5
Adds failing test
dirkgr Aug 7, 2021
36948ce
Merge remote-tracking branch 'origin/main' into TangoBigData
dirkgr Aug 11, 2021
44eccf9
Fix sparse sequence tests
dirkgr Aug 11, 2021
f305de7
Fixes the Sqlite format
dirkgr Aug 11, 2021
61f8810
Quality-of-life hack
dirkgr Aug 11, 2021
989f15c
Makes an internal string less alarming
dirkgr Aug 11, 2021
9c461b7
Save the files to the right place
dirkgr Aug 11, 2021
15e0be4
Merge remote-tracking branch 'origin/main' into TangoBigData
dirkgr Aug 18, 2021
ca26abe
Formatting
dirkgr Aug 19, 2021
f2f0a34
Merge remote-tracking branch 'origin/main' into TangoBigData
dirkgr Aug 19, 2021
bb572b3
Fix for SqliteDatasetFormat
dirkgr Aug 20, 2021
6953d7d
Performance improvement for SqliteSparseSequence
dirkgr Aug 20, 2021
3f99be7
Changelog
dirkgr Aug 20, 2021
d69ea38
Merge branch 'main' into TangoBigData
dirkgr Aug 20, 2021
d58a52f
Global imports
dirkgr Aug 20, 2021
104777d
More Sequence classes
dirkgr Aug 21, 2021
b6b5f05
Say DatasetDict when we mean DatasetDict
dirkgr Aug 21, 2021
05c4dd6
Test for the sequences
dirkgr Aug 21, 2021
4304a93
Use the step name correctly in the error message
dirkgr Aug 21, 2021
d6cb8ab
Use and consume step_name correctly in Step.from_params()
dirkgr Aug 21, 2021
fd305a6
Uncacheable steps don't get cached even if they have a name
dirkgr Aug 21, 2021
3ae61eb
Adds a step that can remix a dataset
dirkgr Aug 21, 2021
2004fd2
Improve log message
dirkgr Aug 21, 2021
b0c3626
Fix relative import
dirkgr Aug 21, 2021
fcf651f
Changelog
dirkgr Aug 21, 2021
aa82e3d
Merge branch 'main' into DatasetRemix
dirkgr Aug 23, 2021
ca5cad3
Adds documentation
dirkgr Aug 23, 2021
d5f11f4
Merge branch 'DatasetRemix' of https://github.com/allenai/allennlp in…
dirkgr Aug 23, 2021
c52b050
Give the option of changing a det_hash simply()
dirkgr Aug 23, 2021
a32c7f2
Tix fypo
dirkgr Aug 23, 2021
6cccd64
Adds ability to shuffle datasets
dirkgr Aug 24, 2021
765575d
Test for det_hash
dirkgr Aug 24, 2021
c69df7e
Merge branch 'main' into DatasetRemix
dirkgr Aug 24, 2021
451e4ee
We don't use relative imports
dirkgr Aug 25, 2021
1d71b69
Merge branch 'DatasetRemix' of https://github.com/allenai/allennlp in…
dirkgr Aug 25, 2021
Adds a step that can remix a dataset
dirkgr committed Aug 21, 2021
commit 3ae61eb6fba7390837f935beadd214f24e88fa47
45 changes: 45 additions & 0 deletions allennlp/tango/dataset.py
@@ -4,11 +4,13 @@
"""

import itertools
import re
from dataclasses import dataclass, field
from typing import Mapping, Any, Optional, Sequence, Dict

from allennlp.data import Vocabulary, DatasetReader, Instance
from allennlp.tango.step import Step
from allennlp.common.sequences import SlicedSequence, ConcatenatedSequence
from tqdm import tqdm


@@ -72,3 +74,46 @@ def run(self, reader: DatasetReader, splits: Dict[str, str]) -> DatasetDict: #
                instance.index_fields(vocab)

        return DatasetDict(splits=instances_map, vocab=vocab)


@Step.register("dataset_remix")
class DatasetRemixStep(Step):
    """
    This step can remix splits in a dataset into new splits.
    """

    DETERMINISTIC = True
    CACHEABLE = False  # This is so fast it's not worth caching.
    VERSION = "001"

    def run(  # type: ignore
        self, input: DatasetDict, new_splits: Dict[str, str], keep_old_splits: bool = True
    ) -> DatasetDict:
        def get_slice(split_name: str) -> Sequence[Any]:
            slice_match = re.match(r"(.*)\[([0123456789:]*)]", split_name)
Contributor

This won't work for something like train[:50000] + dev[:10000]. Is it supposed to?

Member Author

Yes, it will work. This function is only called on the parts after .split("+").

Contributor

Ok I see, and train[:50000] in that case is interpreted as it should be?

Member Author

Yes. It supports full Python slice syntax on line 98. In condensed form, it does slice(*match.split(":")).
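
The condensed form described in this thread can be sketched on its own, without AllenNLP. This is a minimal illustration, not the PR's code: `parse_slice` is a hypothetical helper name, and only the regex and the empty-field-to-`None` rule are copied from the diff.

```python
import re
from typing import Optional, Tuple


def parse_slice(split_name: str) -> Tuple[str, Optional[slice]]:
    """Split a spec such as 'train[:50000]' into a split name and a slice."""
    # Same pattern as in the diff: a name, then digits and colons in brackets.
    match = re.match(r"(.*)\[([0123456789:]*)]", split_name)
    if match is None:
        # No bracket suffix matched, so the whole string is treated as a name.
        return split_name, None
    # Empty fields become None, so ':50000' turns into slice(None, 50000).
    args = [int(a) if len(a) > 0 else None for a in match[2].split(":")]
    return match[1], slice(*args)


name, parsed = parse_slice("train[:50000]")
print(name, parsed)  # train slice(None, 50000, None)
```

Two consequences of this pattern seem worth noting. The character class contains only digits and colons, so a spec with a negative index such as `train[-100:]` does not match and falls through as a plain name. A single number such as `train[5]` becomes `slice(5)`, which selects the first five items rather than the item at index 5.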

            if slice_match is None:
                return input[split_name]
            else:
                split_name = slice_match[1]
                slice_args = [int(a) if len(a) > 0 else None for a in slice_match[2].split(":")]
                return SlicedSequence(input[split_name], slice(*slice_args))

        def parse_split_spec(split_spec: str):
            parts = [get_slice(name.strip()) for name in split_spec.split("+")]
            if len(parts) == 1:
                return parts[0]
            else:
                return ConcatenatedSequence(*parts)

        if keep_old_splits:
            result = dict(input.splits.items())
        else:
            result = {}
        result.update(
            {
                new_split_name: parse_split_spec(new_split_spec)
                for new_split_name, new_split_spec in new_splits.items()
            }
        )

        return DatasetDict(vocab=input.vocab, metadata=input.metadata, splits=result)
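
To see how the pieces of `run()` fit together, the same logic can be approximated with plain dictionaries and lists. This is a sketch under stated assumptions, not the step itself: `remix` and its module-level `get_slice` are hypothetical stand-ins, the splits are ordinary lists rather than a `DatasetDict`, and slices and concatenations are copied eagerly instead of going through the lazy `SlicedSequence` and `ConcatenatedSequence` wrappers used in the diff.

```python
import re
from typing import Any, Dict, List, Sequence


def get_slice(splits: Dict[str, Sequence[Any]], split_name: str) -> Sequence[Any]:
    """Resolve 'name' or 'name[start:stop:step]' against the existing splits."""
    match = re.match(r"(.*)\[([0123456789:]*)]", split_name)
    if match is None:
        return splits[split_name]
    args = [int(a) if len(a) > 0 else None for a in match[2].split(":")]
    # Eager copy; the PR wraps the sequence lazily instead.
    return splits[match[1]][slice(*args)]


def remix(
    splits: Dict[str, Sequence[Any]],
    new_splits: Dict[str, str],
    keep_old_splits: bool = True,
) -> Dict[str, Sequence[Any]]:
    """Build new splits from specs such as 'train[:3] + dev[:2]'."""
    result: Dict[str, Sequence[Any]] = dict(splits) if keep_old_splits else {}
    for new_name, spec in new_splits.items():
        combined: List[Any] = []
        # Each '+'-separated part is resolved against the original splits.
        for part in spec.split("+"):
            combined.extend(get_slice(splits, part.strip()))
        result[new_name] = combined
    return result


old = {"train": list(range(10)), "dev": list(range(100, 105))}
new = remix(
    old,
    {"train_plus": "train[:3] + dev[:2]", "dev_small": "dev[1:4]"},
    keep_old_splits=False,
)
print(new)  # {'train_plus': [0, 1, 2, 100, 101], 'dev_small': [101, 102, 103]}
```

As in the diff, every spec is resolved against the input splits rather than against the result being built, so one new split cannot refer to another new split defined in the same call.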