Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Adds a dataset that can be read and written lazily (#5344)
Browse files Browse the repository at this point in the history
* Adds a dataset that can be read and written lazily

This does not work yet. I'm still working on supporting classes.

* This approach might work better.

* Make ShuffledSequence take indices

* Formatting

* Adds failing test

* Fix sparse sequence tests

* Fixes the Sqlite format

* Quality-of-life hack

* Makes an internal string less alarming

* Save the files to the right place

* Formatting

* Fix for SqliteDatasetFormat

* Performance improvement for SqliteSparseSequence

* Changelog

* Global imports
  • Loading branch information
dirkgr authored Aug 23, 2021
1 parent 01e8a35 commit 5dc80a6
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Tango components, to be explored in detail in a later post
- Added `ScaledDotProductMatrixAttention`, and converted the transformer toolkit to use it
- Added tests to ensure that all `Attention` and `MatrixAttention` implementations are interchangeable
- Added a way for AllenNLP Tango to read and write datasets lazily

### Fixed

Expand Down
8 changes: 8 additions & 0 deletions allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Utilities for working with the local dataset cache.
"""
import string
import weakref
from contextlib import contextmanager
import glob
Expand Down Expand Up @@ -1291,3 +1292,10 @@ def inspect_cache(patterns: List[str] = None, cache_dir: Union[str, Path] = None
f"latest {format_size(size)} from {format_timedelta(td)} ago"
)
print(f"\nTotal size: {format_size(total_size)}")


# Characters allowed to appear in a file name: ASCII letters, digits,
# dash, underscore, and dot.
SAFE_FILENAME_CHARS = frozenset(string.ascii_letters + string.digits + "-_.")


def filename_is_safe(filename: str) -> bool:
    """Return `True` if `filename` contains only filesystem-safe characters."""
    return frozenset(filename).issubset(SAFE_FILENAME_CHARS)
102 changes: 102 additions & 0 deletions allennlp/common/sqlite_sparse_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import shutil
from os import PathLike
from typing import MutableSequence, Any, Union, Iterable

from sqlitedict import SqliteDict

from allennlp.tango.dataloader import ShuffledSequence


class SqliteSparseSequence(MutableSequence[Any]):
    """
    A `MutableSequence` backed by an on-disk Sqlite table, so it can be read
    and written lazily, one item at a time.

    The sequence is "sparse": indices inside the current length that were
    never written read back as `None` instead of raising.
    """

    def __init__(self, filename: Union[str, PathLike], read_only: bool = False):
        # Flag "c" creates the table if it does not exist; "r" opens read-only.
        self.table = SqliteDict(filename, "sparse_sequence", flag="r" if read_only else "c")

    def __del__(self):
        self.close()

    def __getitem__(self, i: Union[int, slice]) -> Any:
        if isinstance(i, int):
            try:
                return self.table[str(i)]
            except KeyError:
                current_length = len(self)
                if i < 0 and -current_length <= i:
                    # Negative indices wrap around, like a list's.
                    return self.__getitem__(i % current_length)
                elif 0 <= i < current_length:
                    # In bounds but never written: this is the "sparse" part.
                    return None
                else:
                    raise IndexError("list index out of range")
        elif isinstance(i, slice):
            # Return a lazy view over the db instead of copying rows out.
            return ShuffledSequence(self, range(*i.indices(len(self))))
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def __setitem__(self, i: Union[int, slice], value: Any):
        if isinstance(i, int):
            current_length = len(self)
            if i < 0:
                i %= current_length
            self.table[str(i)] = value
            # Writing at or past the end grows the sequence; skipped slots
            # read back as None. (The original used `max(i, current_length)`,
            # an off-by-one that hid the item written at index == len.)
            self.table["_len"] = max(i + 1, current_length)
            self.table.commit()
        else:
            raise TypeError(f"list indices must be integers, not {i.__class__.__name__}")

    def __delitem__(self, i: Union[int, slice]):
        current_length = len(self)
        if isinstance(i, int):
            if i < 0:
                i += current_length
            if not 0 <= i < current_length:
                raise IndexError("list assignment index out of range")
            # Shift everything after i down one slot.
            for index in range(i + 1, current_length):
                self.table[str(index - 1)] = self.table.get(str(index))
            try:
                del self.table[str(current_length - 1)]
            except KeyError:
                pass  # the last slot was sparse and never actually stored
            self.table["_len"] = current_length - 1
            self.table.commit()
        elif isinstance(i, slice):
            # This isn't very efficient for continuous slices.
            for index in reversed(range(*i.indices(current_length))):
                del self[index]
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def extend(self, values: Iterable[Any]) -> None:
        current_length = len(self)
        new_length = current_length
        for value in values:
            self.table[str(new_length)] = value
            new_length += 1
        # Update the length and commit once, after the loop, so an empty
        # iterable is a no-op and bulk writes are a single transaction.
        if new_length > current_length:
            self.table["_len"] = new_length
            self.table.commit()

    def insert(self, i: int, value: Any) -> None:
        current_length = len(self)
        # Shift the tail up one slot to make room at i.
        for index in reversed(range(i, current_length)):
            self.table[str(index + 1)] = self.table.get(str(index))
        self.table[str(i)] = value
        self.table["_len"] = current_length + 1
        self.table.commit()

    def __len__(self) -> int:
        # The length is stored under a reserved key; absent means empty.
        try:
            return self.table["_len"]
        except KeyError:
            return 0

    def clear(self) -> None:
        self.table.clear()
        self.table.commit()

    def close(self) -> None:
        # May run from __del__ even if __init__ raised, so be defensive
        # about `table` not existing yet.
        table = getattr(self, "table", None)
        if table is not None:
            table.close()
            self.table = None

    def copy_to(self, target: Union[str, PathLike]):
        try:
            # Hard-linking is instantaneous and uses no extra disk space.
            os.link(self.table.filename, target)
        except OSError as e:
            if e.errno == 18:  # EXDEV: target is on a different filesystem
                shutil.copy(self.table.filename, target)
            else:
                raise
5 changes: 1 addition & 4 deletions allennlp/data/fields/index_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,11 @@ def __init__(self, index: int, sequence_field: SequenceField) -> None:

@overrides
def get_padding_lengths(self) -> Dict[str, int]:

return {}

@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:

tensor = torch.LongTensor([self.sequence_index])
return tensor
return torch.LongTensor([self.sequence_index])

@overrides
def empty_field(self):
Expand Down
1 change: 0 additions & 1 deletion allennlp/data/fields/label_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def get_padding_lengths(self) -> Dict[str, int]:

@overrides
def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:

tensor = torch.tensor(self._label_id, dtype=torch.long)
return tensor

Expand Down
14 changes: 8 additions & 6 deletions allennlp/tango/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,14 @@ class ShuffledSequence(abc.Sequence):
are undefined.
"""

def __init__(self, inner_sequence: Sequence):
def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):
self.inner = inner_sequence
self.indices = list(range(len(inner_sequence)))
random.shuffle(self.indices)
self.indices: Sequence[int]
if indices is None:
self.indices = list(range(len(inner_sequence)))
random.shuffle(self.indices)
else:
self.indices = indices

def __len__(self) -> int:
return len(self.inner)
Expand All @@ -106,9 +110,7 @@ def __getitem__(self, i: Union[int, slice]):
if isinstance(i, int):
return self.inner[self.indices[i]]
else:
result = ShuffledSequence(self.inner)
result.indices = self.indices[i]
return result
return ShuffledSequence(self.inner, self.indices[i])

def __contains__(self, item) -> bool:
return self.inner.__contains__(item)
Expand Down
45 changes: 45 additions & 0 deletions allennlp/tango/sqlite_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import gzip
import pathlib
from os import PathLike
from typing import Union

import dill

from allennlp.common.file_utils import filename_is_safe
from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence
from allennlp.tango.dataset import DatasetDict
from allennlp.tango.format import Format


@Format.register("sqlite")
class SqliteDictFormat(Format[DatasetDict]):
    """
    A `Format` that stores a `DatasetDict` as gzipped dill files (vocab and
    metadata) plus one Sqlite file per split, so splits can be read and
    written lazily.
    """

    VERSION = 2

    def write(self, artifact: DatasetDict, dir: Union[str, PathLike]):
        """Serialize `artifact` into the directory `dir`."""
        dir = pathlib.Path(dir)
        with gzip.open(dir / "vocab.dill.gz", "wb") as f:
            dill.dump(artifact.vocab, f)
        with gzip.open(dir / "metadata.dill.gz", "wb") as f:
            dill.dump(artifact.metadata, f)
        for split_name, split in artifact.splits.items():
            filename = f"{split_name}.sqlite"
            # Split names become file names, so they must not be able to
            # escape `dir` or contain unsafe characters.
            if not filename_is_safe(filename):
                raise ValueError(f"{split_name} is not a valid name for a split.")
            # Remove any stale copy first: os.link() (used by copy_to) fails
            # when the target exists, and extending an old db would mix data.
            (dir / filename).unlink(missing_ok=True)
            if isinstance(split, SqliteSparseSequence):
                # Already sqlite-backed: link/copy the db file into place.
                # (Fix: the original passed `filename`, writing relative to
                # the current working directory instead of `dir`.)
                split.copy_to(dir / filename)
            else:
                sqlite = SqliteSparseSequence(dir / filename)
                sqlite.extend(split)

    def read(self, dir: Union[str, PathLike]) -> DatasetDict:
        """Deserialize a `DatasetDict` from the directory `dir`; splits are opened read-only and lazily."""
        dir = pathlib.Path(dir)
        with gzip.open(dir / "vocab.dill.gz", "rb") as f:
            vocab = dill.load(f)
        with gzip.open(dir / "metadata.dill.gz", "rb") as f:
            metadata = dill.load(f)
        splits = {
            filename.stem: SqliteSparseSequence(filename, read_only=True)
            for filename in dir.glob("*.sqlite")
        }
        return DatasetDict(vocab=vocab, metadata=metadata, splits=splits)
2 changes: 2 additions & 0 deletions allennlp/tango/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def __init__(

if step_format is None:
self.format = self.FORMAT
if isinstance(self.format, type):
self.format = self.format()
else:
self.format = step_format

Expand Down
2 changes: 1 addition & 1 deletion allennlp/training/metrics/fbeta_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class FBetaMeasure(Metric):
alters 'macro' to account for label imbalance; it can result in an
F-score that is not between precision and recall.
labels: `list`, optional
labels : `list`, optional
The set of labels to include and their order if `average is None`.
Labels present in the data can be excluded, for example to calculate a
multi-class average ignoring a majority negative class. Labels not present
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"datasets>=1.2.1,<2.0",
"dill",
"base58",
"sqlitedict",
"google-cloud-storage>=1.38.0,<1.43.0",
],
entry_points={"console_scripts": ["allennlp=allennlp.__main__:run"]},
Expand Down
23 changes: 23 additions & 0 deletions tests/common/sqlite_sparse_sequence_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os
from tempfile import TemporaryDirectory

from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence


def test_sqlite_sparse_sequence():
    with TemporaryDirectory(prefix="test_sparse_sequence-") as temp_dir:
        sequence = SqliteSparseSequence(os.path.join(temp_dir, "test.sqlite"))
        # Starts empty.
        assert len(sequence) == 0
        # A single append.
        sequence.append("one")
        assert len(sequence) == 1
        # Bulk extend, then insert a duplicate in the middle.
        sequence.extend(["two", "three"])
        sequence.insert(1, "two")
        assert sequence[1] == "two"
        assert sequence.count("two") == 2
        # Slicing returns a lazy view over the same data.
        view = sequence[1:3]
        assert list(view) == ["two", "two"]
        # Slice deletion, then negative indexing into what remains.
        del sequence[1:3]
        assert len(sequence) == 2
        assert sequence[-1] == "three"
        # clear() empties the table completely.
        sequence.clear()
        assert len(sequence) == 0

0 comments on commit 5dc80a6

Please sign in to comment.