Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
new data loading (#4497)
Browse files Browse the repository at this point in the history
* first implementation

* update docstrings

* fixes

* fix sharding logic

* clean up DatasetReader

* fix samplers

* fixes

* fixes

* patch models for now

* more fixes

* fix linting error

* fix model test case

* some fixes

* fix linting err

* updates

* rename dataloader -> data_loader

* fixes

* more JoinableQueue

* set daemon=True

* fixes

* fix

* fixes

* fix

* update shuffle logic

* load instances right away when not lazy

* add tqdm when num_workers <= 0

* apply_token_indexers

* fix bug causing high mem usage

* address some of @dirkgr's comments

* fix lazy

* use sensible default for max_batches_in_mem

* ensure workers terminated on err

* fix

* start adding some tests

* more tests

* add some more tests

* address most of Matt's comments

* update PyTorchDataLoader test

* get rid of lazy option

* fix linting

* update docs, change max_batches_per_epoch to max_instances_per_epcoh

* update CHANGELOG

* fix drop_last validation

* fix py2md test fixture

* handle drop_last

* update docs

* implement sharding for most readers

* fix worker init fn

* limit tqdm output

* fixes
  • Loading branch information
epwalsh authored Aug 20, 2020
1 parent 6f82005 commit e74a736
Show file tree
Hide file tree
Showing 58 changed files with 1,818 additions and 1,474 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ jobs:
run: |
git clone https://github.com/allenai/allennlp-models.git
cd allennlp-models
# TODO: remove this next line when it's no longer necessary.
git checkout vision
pip install --upgrade --upgrade-strategy eager -e . -r dev-requirements.txt
- name: Run models tests
Expand Down
10 changes: 9 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- A new high-performance default `DataLoader`: `MultiProcessDataLoading`.

### Changed

- `DatasetReader`s are now always lazy. This means there is no `lazy` parameter in the base
class, and the `_read()` method should always be a generator.
- The `DataLoader` now decides whether to load instances lazily or not.
With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with
the `MultiProcessDataLoading` this is controlled by the `max_instances_in_memory` setting.
- Added a workflow to GitHub Actions that will automatically close unassigned stale issues and
ping the assignees of assigned stale issues.

Expand All @@ -26,7 +35,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed problem with automatically detecting whether tokenization is necessary.
This affected primarily the Roberta SST model.


## [v1.1.0rc2](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc2) - 2020-07-31

### Changed
Expand Down
21 changes: 12 additions & 9 deletions allennlp/commands/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,29 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
dataset_reader = DatasetReader.from_params(validation_dataset_reader_params)
else:
dataset_reader = DatasetReader.from_params(config.pop("dataset_reader"))

evaluation_data_path = args.input_file
logger.info("Reading evaluation data from %s", evaluation_data_path)
instances = dataset_reader.read(evaluation_data_path)

data_loader_params = config.pop("validation_data_loader", None)
if data_loader_params is None:
data_loader_params = config.pop("data_loader")
if args.batch_size:
data_loader_params["batch_size"] = args.batch_size
data_loader = DataLoader.from_params(
params=data_loader_params, reader=dataset_reader, data_path=evaluation_data_path
)

embedding_sources = (
json.loads(args.embedding_sources_mapping) if args.embedding_sources_mapping else {}
)

if args.extend_vocab:
logger.info("Vocabulary is being extended with test instances.")
model.vocab.extend_from_instances(instances=instances)
model.vocab.extend_from_instances(instances=data_loader.iter_instances())
model.extend_embedder_vocab(embedding_sources)

instances.index_with(model.vocab)
data_loader_params = config.pop("validation_data_loader", None)
if data_loader_params is None:
data_loader_params = config.pop("data_loader")
if args.batch_size:
data_loader_params["batch_size"] = args.batch_size
data_loader = DataLoader.from_params(dataset=instances, params=data_loader_params)
data_loader.index_with(model.vocab)

metrics = evaluate(model, data_loader, args.cuda_device, args.batch_weight_key)

Expand Down
20 changes: 9 additions & 11 deletions allennlp/commands/find_learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.common.util import prepare_environment
from allennlp.data import Vocabulary
from allennlp.data import DataLoader
from allennlp.models import Model
from allennlp.training import GradientDescentTrainer, Trainer
from allennlp.training.util import create_serialization_dir, datasets_from_params
from allennlp.training.util import create_serialization_dir, data_loaders_from_params

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -161,11 +160,11 @@ def find_learning_rate_model(
# See https://github.com/allenai/allennlp/issues/3658
assert not distributed_params, "find-lr is not compatible with DistributedDataParallel."

all_datasets = datasets_from_params(params)
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))
all_data_loaders = data_loaders_from_params(params)
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_data_loaders))

for dataset in datasets_for_vocab_creation:
if dataset not in all_datasets:
if dataset not in all_data_loaders:
raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

logger.info(
Expand All @@ -176,16 +175,15 @@ def find_learning_rate_model(
params.pop("vocabulary", {}),
instances=(
instance
for key, dataset in all_datasets.items()
for instance in dataset
for key, data_loader in all_data_loaders.items()
if key in datasets_for_vocab_creation
for instance in data_loader.iter_instances()
),
)

train_data = all_datasets["train"]
train_data.index_with(vocab)
model = Model.from_params(vocab=vocab, params=params.pop("model"))
data_loader = DataLoader.from_params(dataset=train_data, params=params.pop("data_loader"))

all_data_loaders["train"].index_with(vocab)

trainer_params = params.pop("trainer")

Expand All @@ -202,7 +200,7 @@ def find_learning_rate_model(
trainer: GradientDescentTrainer = Trainer.from_params( # type: ignore
model=model,
serialization_dir=serialization_dir,
data_loader=data_loader,
data_loader=all_data_loaders["train"],
params=trainer_params,
)

Expand Down
85 changes: 50 additions & 35 deletions allennlp/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,64 +560,98 @@ def from_partial_objects(
In a typical AllenNLP configuration file, this parameter does not get an entry as a
top-level key, it gets passed in separately.
local_rank: `int`
The process index that is initialized using the GPU device id.
In a typical AllenNLP configuration file, this parameter does not get an entry as a
top-level key, it gets passed in separately.
dataset_reader: `DatasetReader`
The `DatasetReader` that will be used for training and (by default) for validation.
train_data_path: `str`
The file (or directory) that will be passed to `dataset_reader.read()` to construct the
training data.
model: `Lazy[Model]`
The model that we will train. This is lazy because it depends on the `Vocabulary`;
after constructing the vocabulary we call `model.construct(vocab=vocabulary)`.
data_loader: `Lazy[DataLoader]`
The data_loader we use to batch instances from the dataset reader at training and (by
default) validation time. This is lazy because it takes a dataset in it's constructor.
trainer: `Lazy[Trainer]`
The `Trainer` that actually implements the training loop. This is a lazy object because
it depends on the model that's going to be trained.
vocabulary: `Lazy[Vocabulary]`, optional (default=`None`)
The `Vocabulary` that we will use to convert strings in the data to integer ids (and
possibly set sizes of embedding matrices in the `Model`). By default we construct the
vocabulary from the instances that we read.
datasets_for_vocab_creation: `List[str]`, optional (default=`None`)
If you pass in more than one dataset but don't want to use all of them to construct a
vocabulary, you can pass in this key to limit it. Valid entries in the list are
"train", "validation" and "test".
validation_dataset_reader: `DatasetReader`, optional (default=`None`)
If given, we will use this dataset reader for the validation data instead of
`dataset_reader`.
validation_data_path: `str`, optional (default=`None`)
If given, we will use this data for computing validation metrics and early stopping.
validation_data_loader: `Lazy[DataLoader]`, optional (default=`None`)
If given, the data_loader we use to batch instances from the dataset reader at
validation and test time. This is lazy because it takes a dataset in it's constructor.
test_data_path: `str`, optional (default=`None`)
If given, we will use this as test data. This makes it available for vocab creation by
default, but nothing else.
evaluate_on_test: `bool`, optional (default=`False`)
If given, we will evaluate the final model on this data at the end of training. Note
that we do not recommend using this for actual test data in every-day experimentation;
you should only very rarely evaluate your model on actual test data.
batch_weight_key: `str`, optional (default=`""`)
The name of metric used to weight the loss on a per-batch basis. This is only used
during evaluation on final test data, if you've specified `evaluate_on_test=True`.
"""
# Train data loader.
data_loaders: Dict[str, DataLoader] = {
"train": data_loader.construct(reader=dataset_reader, data_path=train_data_path)
}

datasets = training_util.read_all_datasets(
train_data_path=train_data_path,
dataset_reader=dataset_reader,
validation_dataset_reader=validation_dataset_reader,
validation_data_path=validation_data_path,
test_data_path=test_data_path,
)
# Validation data loader.
if validation_data_path is not None:
validation_dataset_reader = validation_dataset_reader or dataset_reader
validation_data_loader_ = validation_data_loader.construct(
reader=validation_dataset_reader, data_path=validation_data_path
)
if validation_data_loader_ is None:
validation_data_loader_ = data_loader.construct(
reader=validation_dataset_reader, data_path=validation_data_path
)
data_loaders["validation"] = validation_data_loader_

# Test data loader.
if test_data_path is not None:
test_dataset_reader = validation_dataset_reader or dataset_reader
test_data_loader_ = validation_data_loader.construct(
reader=test_dataset_reader, data_path=test_data_path
)
if test_data_loader_ is None:
test_data_loader_ = data_loader.construct(
reader=test_dataset_reader, data_path=test_data_path
)
data_loaders["test"] = test_data_loader_

if datasets_for_vocab_creation:
for key in datasets_for_vocab_creation:
if key not in datasets:
if key not in data_loaders:
raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {key}")

logger.info(
Expand All @@ -627,9 +661,9 @@ def from_partial_objects(

instance_generator = (
instance
for key, dataset in datasets.items()
for key, data_loader in data_loaders.items()
if datasets_for_vocab_creation is None or key in datasets_for_vocab_creation
for instance in dataset
for instance in data_loader.iter_instances()
)

vocabulary_ = vocabulary.construct(instances=instance_generator)
Expand All @@ -646,42 +680,23 @@ def from_partial_objects(
vocabulary_path = os.path.join(serialization_dir, "vocabulary")
vocabulary_.save_to_files(vocabulary_path)

for dataset in datasets.values():
dataset.index_with(model_.vocab)

data_loader_ = data_loader.construct(dataset=datasets["train"])
validation_data = datasets.get("validation")
if validation_data is not None:
# Because of the way Lazy[T] works, we can't check it's existence
# _before_ we've tried to construct it. It returns None if it is not
# present, so we try to construct it first, and then afterward back off
# to the data_loader configuration used for training if it returns None.
validation_data_loader_ = validation_data_loader.construct(dataset=validation_data)
if validation_data_loader_ is None:
validation_data_loader_ = data_loader.construct(dataset=validation_data)
else:
validation_data_loader_ = None

test_data = datasets.get("test")
if test_data is not None:
test_data_loader = validation_data_loader.construct(dataset=test_data)
if test_data_loader is None:
test_data_loader = data_loader.construct(dataset=test_data)
else:
test_data_loader = None
for data_loader_ in data_loaders.values():
data_loader_.index_with(model_.vocab)

# We don't need to pass serialization_dir and local_rank here, because they will have been
# passed through the trainer by from_params already, because they were keyword arguments to
# construct this class in the first place.
trainer_ = trainer.construct(
model=model_, data_loader=data_loader_, validation_data_loader=validation_data_loader_,
model=model_,
data_loader=data_loaders["train"],
validation_data_loader=data_loaders.get("validation"),
)

return cls(
serialization_dir=serialization_dir,
model=model_,
trainer=trainer_,
evaluation_data_loader=test_data_loader,
evaluation_data_loader=data_loaders.get("test"),
evaluate_on_test=evaluate_on_test,
batch_weight_key=batch_weight_key,
)
Expand Down
26 changes: 13 additions & 13 deletions allennlp/common/testing/model_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def set_up_model(self, param_file, dataset_file):

reader = DatasetReader.from_params(params["dataset_reader"])
# The dataset reader might be lazy, but a lazy list here breaks some of our tests.
instances = reader.read(str(dataset_file))
instances = list(reader.read(str(dataset_file)))
# Use parameters for vocabulary if they are present in the config file, so that choices like
# "non_padded_namespaces", "min_count" etc. can be set if needed.
if "vocabulary" in params:
Expand All @@ -39,12 +39,11 @@ def set_up_model(self, param_file, dataset_file):
vocab = Vocabulary.from_instances(instances)
self.vocab = vocab
self.instances = instances
self.instances.index_with(vocab)
self.model = Model.from_params(vocab=self.vocab, params=params["model"])

# TODO(joelgrus) get rid of these
# (a lot of the model tests use them, so they'll have to be changed)
self.dataset = Batch(list(self.instances))
self.dataset = Batch(self.instances)
self.dataset.index_instances(self.vocab)

def ensure_model_can_train_save_and_load(
Expand Down Expand Up @@ -119,21 +118,22 @@ def ensure_model_can_train_save_and_load(
params = Params.from_file(param_file, params_overrides=overrides)
reader = DatasetReader.from_params(params["dataset_reader"])

print("Reading with original model")
model_dataset = reader.read(params["validation_data_path"])
model_dataset.index_with(model.vocab)

print("Reading with loaded model")
loaded_dataset = reader.read(params["validation_data_path"])
loaded_dataset.index_with(loaded_model.vocab)

# Need to duplicate params because DataLoader.from_params will consume.
data_loader_params = params["data_loader"]
data_loader_params["shuffle"] = False
data_loader_params2 = Params(copy.deepcopy(data_loader_params.as_dict()))

data_loader = DataLoader.from_params(dataset=model_dataset, params=data_loader_params)
data_loader2 = DataLoader.from_params(dataset=loaded_dataset, params=data_loader_params2)
print("Reading with original model")
data_loader = DataLoader.from_params(
params=data_loader_params, reader=reader, data_path=params["validation_data_path"]
)
data_loader.index_with(model.vocab)

print("Reading with loaded model")
data_loader2 = DataLoader.from_params(
params=data_loader_params2, reader=reader, data_path=params["validation_data_path"]
)
data_loader2.index_with(loaded_model.vocab)

# We'll check that even if we index the dataset with each model separately, we still get
# the same result out.
Expand Down
11 changes: 7 additions & 4 deletions allennlp/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from allennlp.data.dataloader import DataLoader, PyTorchDataLoader, allennlp_collate
from allennlp.data.dataset_readers.dataset_reader import (
DatasetReader,
from allennlp.data.data_loaders import (
DataLoader,
PyTorchDataLoader,
TensorDict,
allennlp_collate,
AllennlpDataset,
AllennlpLazyDataset,
)
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields.field import DataArray, Field
from allennlp.data.fields.text_field import TextFieldTensors
from allennlp.data.instance import Instance
from allennlp.data.samplers import BatchSampler, Sampler
from allennlp.data.samplers import BatchSampler, PyTorchSampler, PyTorchBatchSampler
from allennlp.data.token_indexers.token_indexer import TokenIndexer, IndexedTokenList
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.tokenizer import Tokenizer
Expand Down
8 changes: 8 additions & 0 deletions allennlp/data/data_loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from allennlp.data.data_loaders.data_loader import DataLoader, TensorDict, allennlp_collate
from allennlp.data.data_loaders.multi_process_data_loader import MultiProcessDataLoader
from allennlp.data.data_loaders.pytorch_data_loader import (
PyTorchDataLoader,
allennlp_worker_init_fn,
AllennlpDataset,
AllennlpLazyDataset,
)
Loading

0 comments on commit e74a736

Please sign in to comment.