diff --git a/augur/filter.py b/augur/filter.py
index 4c1babab7..0dabcfb3b 100644
--- a/augur/filter.py
+++ b/augur/filter.py
@@ -29,13 +29,30 @@ def read_vcf(filename):
     return sequences, sequences.copy()
 
 
-def write_vcf(compressed, input_file, output_file, dropped_samps):
-    #Read in/write out according to file ending
-    inCall = "--gzvcf" if compressed else "--vcf"
-    outCall = "| gzip -c" if output_file.lower().endswith('.gz') else ""
+def write_vcf(input_filename, output_filename, dropped_samps):
+    """Write the input VCF to the output file, leaving out the dropped samples.
+
+    Assembles a VCFTools shell command: the input is passed with --gzvcf when
+    its name ends in ".gz" (otherwise --vcf), every name in dropped_samps is
+    excluded with --remove-indv, and the recoded output is piped through gzip
+    when the output name ends in ".gz". All file and sample names are passed
+    through shquote before being placed in the command.
+    """
+    if _filename_gz(input_filename):
+        input_arg = "--gzvcf"
+    else:
+        input_arg = "--vcf"
+
+    # Keep the optional gzip stage as a list so that an uncompressed output
+    # adds no empty element (and thus no double space) to the joined command.
+    if _filename_gz(output_filename):
+        output_pipe = ["| gzip -c"]
+    else:
+        output_pipe = []
 
-    toDrop = " ".join(["--remove-indv "+shquote(s) for s in dropped_samps])
-    call = ["vcftools", toDrop, inCall, shquote(input_file), "--recode --stdout", outCall, ">", shquote(output_file)]
+    drop_args = ["--remove-indv " + shquote(s) for s in dropped_samps]
+
+    call = ["vcftools"] + drop_args + [input_arg, shquote(input_filename), "--recode --stdout"] + output_pipe + [">", shquote(output_filename)]
 
     print("Filtering samples using VCFTools with the call:")
     print(" ".join(call))
@@ -339,7 +356,7 @@ def run(args):
         if len(dropped_samps) == len(all_seq): #All samples have been dropped! Stop run, warn user.
             print("ERROR: All samples have been dropped! Check filter rules and metadata file format.")
             return 1
-        write_vcf(is_compressed, args.sequences, args.output, dropped_samps)
+        write_vcf(args.sequences, args.output, dropped_samps)
 
     else:
         seq_to_keep = [seq for id,seq in seqs.items() if id in seq_keep]
@@ -370,3 +387,8 @@ def run(args):
         print("\t%i sequences were added back because of '%s'" % (num_included_by_metadata, args.include_where))
 
     print("%i sequences have been written out to %s" % (len(seq_keep), args.output))
+
+
+def _filename_gz(filename):
+    """Return True if filename ends with ".gz", ignoring case."""
+    return filename.lower().endswith(".gz")
diff --git a/tests/test_filter.py b/tests/test_filter.py
index 860401a51..10a5c6949 100644
--- a/tests/test_filter.py
+++ b/tests/test_filter.py
@@ -1,4 +1,5 @@
 import augur.filter
+
 import pytest
 
 
@@ -14,6 +15,11 @@ def mock_priorities_file_malformed(mocker):
     mocker.patch("builtins.open", mocker.mock_open(read_data="strain1 X\n"))
 
 
+@pytest.fixture
+def mock_run_shell_command(mocker):
+    mocker.patch("augur.filter.run_shell_command")
+
+
 class TestFilter:
     def test_read_vcf_compressed(self):
         seq_keep, all_seq = augur.filter.read_vcf(
@@ -47,3 +53,53 @@ def test_read_priority_scores_malformed(self, mock_priorities_file_malformed):
     def test_read_priority_scores_does_not_exist(self):
         with pytest.raises(FileNotFoundError):
             augur.filter.read_priority_scores("/does/not/exist.txt")
+
+    def test_write_vcf_compressed_input(self, mock_run_shell_command):
+        augur.filter.write_vcf(
+            "tests/builds/tb/data/lee_2015.vcf.gz", "output_file.vcf.gz", []
+        )
+
+        augur.filter.run_shell_command.assert_called_once_with(
+            "vcftools --gzvcf tests/builds/tb/data/lee_2015.vcf.gz --recode --stdout | gzip -c > output_file.vcf.gz",
+            raise_errors=True,
+        )
+
+    def test_write_vcf_uncompressed_input(self, mock_run_shell_command):
+        augur.filter.write_vcf(
+            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf.gz", []
+        )
+
+        augur.filter.run_shell_command.assert_called_once_with(
+            "vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout | gzip -c > output_file.vcf.gz",
+            raise_errors=True,
+        )
+
+    def test_write_vcf_compressed_output(self, mock_run_shell_command):
+        augur.filter.write_vcf(
+            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf.gz", []
+        )
+
+        augur.filter.run_shell_command.assert_called_once_with(
+            "vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout | gzip -c > output_file.vcf.gz",
+            raise_errors=True,
+        )
+
+    def test_write_vcf_uncompressed_output(self, mock_run_shell_command):
+        augur.filter.write_vcf(
+            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf", []
+        )
+
+        augur.filter.run_shell_command.assert_called_once_with(
+            "vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout > output_file.vcf",
+            raise_errors=True,
+        )
+
+    def test_write_vcf_dropped_samples(self, mock_run_shell_command):
+        augur.filter.write_vcf(
+            "tests/builds/tb/data/lee_2015.vcf", "output_file.vcf", ["x", "y", "z"]
+        )
+
+        augur.filter.run_shell_command.assert_called_once_with(
+            "vcftools --remove-indv x --remove-indv y --remove-indv z --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout > output_file.vcf",
+            raise_errors=True,
+        )