This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit: Multiple datasets and output files support for the evaluate command (#5340)

* multiple files evaluation

* Add multiple datasets support for the evaluate command

* Add description for --predictions-output-file for the evaluate command

* Add a test case (multiple inputs/outputs) for the evaluate command

* update change log

* Update allennlp/commands/evaluate.py

assert error message

Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>

* Update allennlp/commands/evaluate.py

Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>

* Update evaluate.py

* for merging

* Fix #5340 (comment)

* Fix #5340 (comment)

Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>
c4n and AkshitaB authored Sep 1, 2021
1 parent 60213cd commit 48af9d3
Showing 3 changed files with 106 additions and 35 deletions.
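
With this change, `allennlp evaluate` accepts ":"-separated lists of paths: one `input_file` entry per dataset, and optionally matching `--output-file` and `--predictions-output-file` entries, e.g. `allennlp evaluate model.tar.gz dev_a.txt:dev_b.txt --output-file metrics_a.json:metrics_b.json` (hypothetical paths). A minimal sketch of the pairing contract the new code enforces, assuming two datasets and two metrics files:

from typing import List, Tuple

def pair_paths(input_arg: str, output_arg: str) -> List[Tuple[str, str]]:
    # Split the ":"-separated CLI values the same way the new command code does.
    inputs = input_arg.split(":")
    outputs = output_arg.split(":")
    # The command asserts a one-to-one pairing of datasets to output files.
    assert len(outputs) == len(inputs), "one output path per dataset"
    return list(zip(inputs, outputs))

# Hypothetical paths, for illustration only.
print(pair_paths("dev_a.txt:dev_b.txt", "metrics_a.json:metrics_b.json"))
# [('dev_a.txt', 'metrics_a.json'), ('dev_b.txt', 'metrics_b.json')]

Using ":" keeps the change backward compatible: a single path contains no separator, so `split(":")` yields a one-element list and the command behaves exactly as before.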
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added support to evaluate multiple datasets and produce corresponding output files in the `evaluate` command.
 - Added more documentation to the learning rate schedulers to include a sample config object for how to use it.
 - Moved the pytorch learning rate schedulers wrappers to their own file called `pytorch_lr_schedulers.py` so that they will have their own documentation page.
 - Added a module `allennlp.nn.parallel` with a new base class, `DdpAccelerator`, which generalizes
110 changes: 75 additions & 35 deletions allennlp/commands/evaluate.py
@@ -9,6 +9,8 @@
 import logging
 from typing import Any, Dict
 
+from copy import deepcopy
+
 from overrides import overrides
 
 from allennlp.commands.subcommand import Subcommand
@@ -25,25 +27,38 @@
 class Evaluate(Subcommand):
     @overrides
     def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
-        description = """Evaluate the specified model + dataset"""
+        description = """Evaluate the specified model + dataset(s)"""
         subparser = parser.add_parser(
-            self.name, description=description, help="Evaluate the specified model + dataset."
+            self.name, description=description, help="Evaluate the specified model + dataset(s)."
         )
 
         subparser.add_argument("archive_file", type=str, help="path to an archived trained model")
 
         subparser.add_argument(
-            "input_file", type=str, help="path to the file containing the evaluation data"
+            "input_file",
+            type=str,
+            help=(
+                "path to the file containing the evaluation data"
+                ' (for multiple files, put ":" between filenames e.g., input1.txt:input2.txt)'
+            ),
         )
 
         subparser.add_argument(
-            "--output-file", type=str, help="optional path to write the metrics to as JSON"
+            "--output-file",
+            type=str,
+            help=(
+                "optional path to write the metrics to as JSON"
+                ' (for multiple files, put ":" between filenames e.g., output1.txt:output2.txt)'
+            ),
         )
 
         subparser.add_argument(
             "--predictions-output-file",
             type=str,
-            help="optional path to write the predictions to as JSON lines",
+            help=(
+                "optional path to write the predictions to as JSON lines"
+                ' (for multiple files, put ":" between filenames e.g., output1.jsonl:output2.jsonl)'
+            ),
         )
 
         subparser.add_argument(
@@ -123,47 +138,72 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
         cuda_device=args.cuda_device,
         overrides=args.overrides,
     )
-    config = archive.config
+    config = deepcopy(archive.config)
     prepare_environment(config)
     model = archive.model
     model.eval()
 
     # Load the evaluation data
 
     dataset_reader = archive.validation_dataset_reader
 
-    evaluation_data_path = args.input_file
-    logger.info("Reading evaluation data from %s", evaluation_data_path)
-
-    data_loader_params = config.pop("validation_data_loader", None)
-    if data_loader_params is None:
-        data_loader_params = config.pop("data_loader")
-    if args.batch_size:
-        data_loader_params["batch_size"] = args.batch_size
-    data_loader = DataLoader.from_params(
-        params=data_loader_params, reader=dataset_reader, data_path=evaluation_data_path
-    )
+    # split files
+    evaluation_data_path_list = args.input_file.split(":")
+    if args.output_file is not None:
+        output_file_list = args.output_file.split(":")
+        assert len(output_file_list) == len(
+            evaluation_data_path_list
+        ), "The number of `output_file` paths must be equal to the number of datasets being evaluated."
+    if args.predictions_output_file is not None:
+        predictions_output_file_list = args.predictions_output_file.split(":")
+        assert len(predictions_output_file_list) == len(evaluation_data_path_list), (
+            "The number of `predictions_output_file` paths must be equal "
+            + "to the number of datasets being evaluated."
+        )
 
-    embedding_sources = (
-        json.loads(args.embedding_sources_mapping) if args.embedding_sources_mapping else {}
-    )
+    # output file
+    output_file_path = None
+    predictions_output_file_path = None
 
+    # embedding sources
     if args.extend_vocab:
-        logger.info("Vocabulary is being extended with test instances.")
-        model.vocab.extend_from_instances(instances=data_loader.iter_instances())
-        model.extend_embedder_vocab(embedding_sources)
-
-    data_loader.index_with(model.vocab)
-
-    metrics = evaluate(
-        model,
-        data_loader,
-        args.cuda_device,
-        args.batch_weight_key,
-        output_file=args.output_file,
-        predictions_output_file=args.predictions_output_file,
-    )
+        logger.info("Vocabulary is being extended with embedding sources.")
+        embedding_sources = (
+            json.loads(args.embedding_sources_mapping) if args.embedding_sources_mapping else {}
+        )
+
+    for index in range(len(evaluation_data_path_list)):
+        config = deepcopy(archive.config)
+        evaluation_data_path = evaluation_data_path_list[index]
+        if args.output_file is not None:
+            output_file_path = output_file_list[index]
+        if args.predictions_output_file is not None:
+            predictions_output_file_path = predictions_output_file_list[index]
+
+        logger.info("Reading evaluation data from %s", evaluation_data_path)
+        data_loader_params = config.get("validation_data_loader", None)
+        if data_loader_params is None:
+            data_loader_params = config.get("data_loader")
+        if args.batch_size:
+            data_loader_params["batch_size"] = args.batch_size
+        data_loader = DataLoader.from_params(
+            params=data_loader_params, reader=dataset_reader, data_path=evaluation_data_path
+        )
+
+        if args.extend_vocab:
+            logger.info("Vocabulary is being extended with test instances.")
+            model.vocab.extend_from_instances(instances=data_loader.iter_instances())
+            model.extend_embedder_vocab(embedding_sources)
+
+        data_loader.index_with(model.vocab)
+
+        metrics = evaluate(
+            model,
+            data_loader,
+            args.cuda_device,
+            args.batch_weight_key,
+            output_file=output_file_path,
+            predictions_output_file=predictions_output_file_path,
+        )
     logger.info("Finished evaluating.")
 
     return metrics
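
Note that `metrics` is reassigned on each pass through the loop above, so `evaluate_from_args` returns only the metrics of the last dataset; per-dataset metrics survive in the individual `--output-file` paths. A minimal sketch of a caller recovering them, assuming two hypothetical metrics files already written by the command:

import json

# Hypothetical per-dataset metrics files written via --output-file.
output_files = ["metrics_a.json", "metrics_b.json"]
per_dataset_metrics = []
for path in output_files:
    # Each file holds the JSON metrics dict that evaluate() wrote for one dataset.
    with open(path) as f:
        per_dataset_metrics.append(json.load(f))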
30 changes: 30 additions & 0 deletions tests/commands/evaluate_test.py
@@ -113,6 +113,36 @@ def test_output_file_evaluate_from_args(self):
                 prediction = json.loads(line.strip())
                 assert "tags" in prediction
 
+    def test_multiple_output_files_evaluate_from_args(self):
+        output_file = str(self.TEST_DIR / "metrics.json")
+        predictions_output_file = str(self.TEST_DIR / "predictions.jsonl")
+        kebab_args = [
+            "evaluate",
+            str(
+                self.FIXTURES_ROOT / "simple_tagger_with_span_f1" / "serialization" / "model.tar.gz"
+            ),
+            str(self.FIXTURES_ROOT / "data" / "conll2003.txt")
+            + ":"
+            + str(self.FIXTURES_ROOT / "data" / "conll2003.txt"),
+            "--cuda-device",
+            "-1",
+            "--output-file",
+            output_file + ":" + output_file,
+            "--predictions-output-file",
+            predictions_output_file + ":" + predictions_output_file,
+        ]
+        args = self.parser.parse_args(kebab_args)
+        computed_metrics = evaluate_from_args(args)
+
+        with open(output_file, "r") as file:
+            saved_metrics = json.load(file)
+        assert computed_metrics == saved_metrics
+
+        with open(predictions_output_file, "r") as file:
+            for line in file:
+                prediction = json.loads(line.strip())
+                assert "tags" in prediction
+
     def test_evaluate_works_with_vocab_expansion(self):
         archive_path = str(
             self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
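The new test evaluates the same fixture twice and reuses a single path per flag, so the second evaluation overwrites the first file and the assertions effectively check the final run (identical here, since both datasets are the same file). A sketch of how distinct per-dataset paths could be built instead, with hypothetical locations:

from pathlib import Path

# Hypothetical scratch directory and distinct per-dataset paths.
test_dir = Path("/tmp/allennlp_eval_test")
output_files = [test_dir / "metrics_1.json", test_dir / "metrics_2.json"]
# Join with ":" to build the value passed to --output-file.
output_file_arg = ":".join(str(p) for p in output_files)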
