diff --git a/augur/curate/__init__.py b/augur/curate/__init__.py index 46f5b7070..369844354 100644 --- a/augur/curate/__init__.py +++ b/augur/curate/__init__.py @@ -122,7 +122,7 @@ def register_parser(parent_subparsers): return parser -def validate_records(records: Iterable[dict], is_input: bool) -> Iterable[dict]: +def validate_records(records: Iterable[dict], subcmd_name: str, is_input: bool) -> Iterable[dict]: """ Validate that the provided *records* all have the same fields. Uses the keys of the first record to check against all other records. @@ -131,6 +131,10 @@ def validate_records(records: Iterable[dict], is_input: bool) -> Iterable[dict]: ---------- records: iterable of dict + subcmd_name: str + The name of the subcommand whose output is being validated; used in + error messages displayed to the user. + is_input: bool Whether the provided records come directly from user provided input """ @@ -140,8 +144,8 @@ def validate_records(records: Iterable[dict], is_input: bool) -> Iterable[dict]: else: # Hopefully users should not run into this error as it means we are # not uniformly adding/removing fields from records - error_message += dedent("""\ - Something unexpected happened during the augur curate command. + error_message += dedent(f"""\ + Something unexpected happened during the augur curate {subcmd_name} command. To report this, please open a new issue including the original command: """) @@ -213,14 +217,17 @@ def run(args): input files can be provided via the command line options `--metadata` and `--fasta`. See the command's help message for more details.""")) + # Get the name of the subcmd being run + subcmd_name = args.subcommand + # Validate records have the same input fields - validated_input_records = validate_records(records, True) + validated_input_records = validate_records(records, subcmd_name, True) # Run subcommand to get modified records modified_records = getattr(args, SUBCOMMAND_ATTRIBUTE).run(args, validated_input_records) # Validate modified records have the same output fields - validated_output_records = validate_records(modified_records, False) + validated_output_records = validate_records(modified_records, subcmd_name, False) # Output modified records # First output FASTA, since the write fasta function yields the records again diff --git a/tests/io/test_curate_validate_records.py b/tests/io/test_curate_validate_records.py new file mode 100644 index 000000000..72f5083f4 --- /dev/null +++ b/tests/io/test_curate_validate_records.py @@ -0,0 +1,47 @@ +import pytest +from augur.curate import validate_records +from augur.errors import AugurError + + +@pytest.fixture +def good_records(): + return [ + {"geo_loc_name": "Canada/Vancouver"}, + {"geo_loc_name": "Canada/Vancouver"}, + ] + + +@pytest.fixture +def bad_records(): + return [ + {"geo_loc_name": "Canada/Vancouver"}, + {"geo_loc_name2": "Canada/Vancouver"}, + ] + + +class TestCurateValidateRecords: + def test_validate_input(self, good_records): + validated_records = validate_records(good_records, "test_subcmd", True) + assert list(validated_records) == good_records, "good input records validate" + + def test_validate_output(self, good_records): + validated_records = validate_records(good_records, "test_subcmd", False) + + assert list(validated_records) == good_records, "good output records validate" + + def test_validate_bad_records(self, bad_records): + with pytest.raises(AugurError) as e: + list(validate_records(bad_records, "test_subcmd", True)) + assert str(e.value).startswith( + "Records do not have the same fields!" + ), "bad input records throw exception with expected message" + + def test_validate_bad_output(self, bad_records): + with pytest.raises(AugurError) as e: + list(validate_records(bad_records, "test_subcmd", False)) + assert str(e.value).startswith( + "Records do not have the same fields!" + ), "bad output records throw exception with expected message" + assert ( + "test_subcmd" in str(e.value) + ), "bad output records throw exception with subcmd name in the message"