This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a dataset that can be read and written lazily (#5344)
* Adds a dataset that can be read and written lazily. This does not work yet. I'm still working on supporting classes.
* This approach might work better.
* Make ShuffledSequence take indices
* Formatting
* Adds failing test
* Fix sparse sequence tests
* Fixes the Sqlite format
* Quality-of-life hack
* Makes an internal string less alarming
* Save the files to the right place
* Formatting
* Fix for SqliteDatasetFormat
* Performance improvement for SqliteSparseSequence
* Changelog
* Global imports
- Loading branch information
Showing 11 changed files with 192 additions and 12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import os | ||
import shutil | ||
from os import PathLike | ||
from typing import MutableSequence, Any, Union, Iterable | ||
|
||
from sqlitedict import SqliteDict | ||
|
||
from allennlp.tango.dataloader import ShuffledSequence | ||
|
||
|
||
class SqliteSparseSequence(MutableSequence[Any]):
    """A list-like sequence whose items are persisted in a ``SqliteDict`` table.

    Every item is stored under the string form of its index in the
    ``"sparse_sequence"`` table, and the logical length is stored under the
    ``"_len"`` key. Positions below the length that have no stored entry read
    back as ``None``; that is what makes the sequence "sparse".
    """

    def __init__(self, filename: Union[str, PathLike], read_only: bool = False):
        """Open the table in ``filename`` with flag ``"r"`` if ``read_only`` else ``"c"``."""
        self.table = SqliteDict(filename, "sparse_sequence", flag="r" if read_only else "c")

    def __del__(self):
        self.close()

    def __getitem__(self, i: Union[int, slice]) -> Any:
        """Return the item at index ``i``, or a ``ShuffledSequence`` view for a slice.

        Raises:
            IndexError: if an integer index falls outside ``[-len, len)``.
            TypeError: if ``i`` is neither an ``int`` nor a ``slice``.
        """
        if isinstance(i, int):
            try:
                # Fast path: stored, non-negative indices need no length lookup.
                return self.table[str(i)]
            except KeyError:
                current_length = len(self)
                if i >= current_length or i < -current_length:
                    raise IndexError("list index out of range") from None
                if i < 0:
                    return self[i + current_length]
                # In range but never written: a sparse slot.
                return None
        elif isinstance(i, slice):
            return ShuffledSequence(self, range(*i.indices(len(self))))
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def __setitem__(self, i: Union[int, slice], value: Any):
        """Store ``value`` at integer index ``i``, growing the sequence if needed.

        Writing past the current end is allowed; the skipped positions stay
        sparse and read back as ``None``.

        Raises:
            IndexError: if a negative index reaches before the first element.
            TypeError: if ``i`` is not an ``int``.
        """
        if not isinstance(i, int):
            raise TypeError(f"list indices must be integers, not {i.__class__.__name__}")
        current_length = len(self)
        if i < 0:
            if i < -current_length:
                raise IndexError("list assignment index out of range")
            i += current_length
        self.table[str(i)] = value
        # Index i is occupied now, so the sequence holds at least i + 1 items.
        self.table["_len"] = max(i + 1, current_length)
        self.table.commit()

    def __delitem__(self, i: Union[int, slice]):
        """Delete the item (or slice of items) at ``i``, shifting later items down.

        Raises:
            IndexError: if an integer index falls outside ``[-len, len)``.
            TypeError: if ``i`` is neither an ``int`` nor a ``slice``.
        """
        current_length = len(self)
        if isinstance(i, int):
            if i < 0:
                i += current_length
            if not 0 <= i < current_length:
                raise IndexError("list assignment index out of range")
            for index in range(i + 1, current_length):
                self.table[str(index - 1)] = self.table.get(str(index))
            # The trailing slot may be sparse, in which case there is no row to drop.
            last_key = str(current_length - 1)
            if last_key in self.table:
                del self.table[last_key]
            self.table["_len"] = current_length - 1
            self.table.commit()
        elif isinstance(i, slice):
            # This isn't very efficient for continuous slices.
            # Delete from the highest index down so the remaining indices stay
            # valid whatever the sign of the slice step.
            for index in sorted(range(*i.indices(current_length)), reverse=True):
                del self[index]
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def extend(self, values: Iterable[Any]) -> None:
        """Append every item of ``values``; an empty iterable leaves the table untouched."""
        if values is self:
            # Snapshot first: iterating this sequence while appending to it
            # would keep finding the rows that were just written.
            values = list(values)
        current_length = len(self)
        next_index = current_length
        for value in values:
            self.table[str(next_index)] = value
            next_index += 1
        if next_index != current_length:
            self.table["_len"] = next_index
            self.table.commit()

    def insert(self, i: int, value: Any) -> None:
        """Insert ``value`` before index ``i``, clamping ``i`` the way ``list.insert`` does."""
        current_length = len(self)
        if i < 0:
            i = max(i + current_length, 0)
        else:
            i = min(i, current_length)
        # Walk backwards so each item is moved before its slot is overwritten.
        for index in reversed(range(i, current_length)):
            self.table[str(index + 1)] = self.table.get(str(index))
        self.table[str(i)] = value
        self.table["_len"] = current_length + 1
        self.table.commit()

    def __len__(self) -> int:
        """Return the stored length, or 0 when no ``"_len"`` entry exists yet."""
        try:
            return self.table["_len"]
        except KeyError:
            return 0

    def clear(self) -> None:
        """Remove every entry, including the stored length, and commit."""
        self.table.clear()
        self.table.commit()

    def close(self) -> None:
        """Close the underlying table; safe to call more than once."""
        # ``table`` is absent when ``__init__`` raised before assigning it,
        # which matters because ``__del__`` still runs in that case.
        table = getattr(self, "table", None)
        if table is not None:
            table.close()
            self.table = None

    def copy_to(self, target: Union[str, PathLike]):
        """Hard-link the backing file to ``target``, copying instead across devices.

        Raises:
            OSError: for any linking failure other than a cross-device link.
        """
        import errno

        try:
            os.link(self.table.filename, target)
        except OSError as e:
            if e.errno == errno.EXDEV:  # Cross-device link
                shutil.copy(self.table.filename, target)
            else:
                raise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import gzip | ||
import pathlib | ||
from os import PathLike | ||
from typing import Union | ||
|
||
import dill | ||
|
||
from allennlp.common.file_utils import filename_is_safe | ||
from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence | ||
from allennlp.tango.dataset import DatasetDict | ||
from allennlp.tango.format import Format | ||
|
||
|
||
@Format.register("sqlite")
class SqliteDictFormat(Format[DatasetDict]):
    """Stores a ``DatasetDict`` as dill files plus one sqlite file per split.

    The vocabulary and metadata go into gzip-compressed dill files, and each
    split goes into ``<split_name>.sqlite`` as a ``SqliteSparseSequence``.
    """

    VERSION = 2

    def write(self, artifact: DatasetDict, dir: Union[str, PathLike]):
        """Write ``artifact`` into the directory ``dir``.

        Raises:
            ValueError: if a split name does not yield a safe file name.
        """
        dir = pathlib.Path(dir)
        with gzip.open(dir / "vocab.dill.gz", "wb") as f:
            dill.dump(artifact.vocab, f)
        with gzip.open(dir / "metadata.dill.gz", "wb") as f:
            dill.dump(artifact.metadata, f)
        for split_name, split in artifact.splits.items():
            filename = f"{split_name}.sqlite"
            if not filename_is_safe(filename):
                raise ValueError(f"{split_name} is not a valid name for a split.")
            target = dir / filename
            if isinstance(split, SqliteSparseSequence):
                # A split that already lives at the target needs no copy, and
                # unlinking it would destroy the only path to its data.
                if target.exists() and target.samefile(split.table.filename):
                    continue
                # Linking fails when the target exists, so drop any stale file.
                target.unlink(missing_ok=True)
                split.copy_to(target)
            else:
                target.unlink(missing_ok=True)
                sqlite = SqliteSparseSequence(target)
                try:
                    sqlite.extend(split)
                finally:
                    # Close explicitly rather than relying on garbage collection.
                    sqlite.close()

    def read(self, dir: Union[str, PathLike]) -> DatasetDict:
        """Load a ``DatasetDict`` from ``dir``, opening each split read-only.

        Every ``*.sqlite`` file in ``dir`` becomes a split named after the
        file's stem.
        """
        dir = pathlib.Path(dir)
        with gzip.open(dir / "vocab.dill.gz", "rb") as f:
            vocab = dill.load(f)
        with gzip.open(dir / "metadata.dill.gz", "rb") as f:
            metadata = dill.load(f)
        splits = {
            filename.stem: SqliteSparseSequence(filename, read_only=True)
            for filename in dir.glob("*.sqlite")
        }
        return DatasetDict(vocab=vocab, metadata=metadata, splits=splits)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
from tempfile import TemporaryDirectory | ||
|
||
from allennlp.common.sqlite_sparse_sequence import SqliteSparseSequence | ||
|
||
|
||
def test_sqlite_sparse_sequence():
    """Exercise append, extend, insert, slicing, deletion and clear on a fresh file."""
    with TemporaryDirectory(prefix="test_sparse_sequence-") as temp_dir:
        s = SqliteSparseSequence(os.path.join(temp_dir, "test.sqlite"))
        try:
            assert len(s) == 0
            s.append("one")
            assert len(s) == 1
            s.extend(["two", "three"])
            s.insert(1, "two")
            assert s[1] == "two"
            assert s.count("two") == 2
            ss = s[1:3]
            assert list(ss) == ["two", "two"]
            del s[1:3]
            assert len(s) == 2
            assert s[-1] == "three"
            s.clear()
            assert len(s) == 0
        finally:
            # Release the sqlite file before the temporary directory is removed.
            s.close()