-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Translate original filter pytests to work with an in-memory database. - Move VCF-related tests to a better home in test_utils. - Split filter tests into separate files: - file loading - filtering - TODO: subsampling
- Loading branch information
Showing
5 changed files
with
486 additions
and
444 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,267 +1,51 @@ | ||
import argparse | ||
import numpy as np | ||
import random | ||
import shlex | ||
|
||
import pytest | ||
|
||
from Bio import SeqIO | ||
from Bio.Seq import Seq | ||
from Bio.SeqRecord import SeqRecord | ||
import sqlite3 | ||
|
||
import augur.filter | ||
from augur.utils import read_metadata | ||
from augur.filter_support.db.sqlite import FilterSQLite | ||
|
||
@pytest.fixture | ||
def argparser(): | ||
|
||
def parse_args(args:str): | ||
parser = argparse.ArgumentParser() | ||
augur.filter.register_arguments(parser) | ||
def parse(args): | ||
return parser.parse_args(shlex.split(args)) | ||
return parse | ||
|
||
@pytest.fixture | ||
def sequences(): | ||
def random_seq(k): | ||
return "".join(random.choices(("A","T","G","C"), k=k)) | ||
return { | ||
"SEQ_1": SeqRecord(Seq(random_seq(10)), id="SEQ_1"), | ||
"SEQ_2": SeqRecord(Seq(random_seq(10)), id="SEQ_2"), | ||
"SEQ_3": SeqRecord(Seq(random_seq(10)), id="SEQ_3"), | ||
} | ||
|
||
@pytest.fixture | ||
def fasta_fn(tmpdir, sequences): | ||
fn = str(tmpdir / "sequences.fasta") | ||
SeqIO.write(sequences.values(), fn, "fasta") | ||
return fn | ||
|
||
def write_metadata(tmpdir, metadata): | ||
fn = str(tmpdir / "metadata.tsv") | ||
with open(fn, "w") as fh: | ||
fh.write("\n".join(("\t".join(md) for md in metadata))) | ||
return fn | ||
|
||
@pytest.fixture | ||
def mock_priorities_file_valid(mocker): | ||
mocker.patch( | ||
"builtins.open", mocker.mock_open(read_data="strain1 5\nstrain2 6\nstrain3 8\n") | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mock_priorities_file_malformed(mocker): | ||
mocker.patch("builtins.open", mocker.mock_open(read_data="strain1 X\n")) | ||
|
||
|
||
@pytest.fixture | ||
def mock_run_shell_command(mocker): | ||
mocker.patch("augur.filter.run_shell_command") | ||
|
||
|
||
@pytest.fixture | ||
def mock_priorities_file_valid_with_spaces_and_tabs(mocker): | ||
mocker.patch( | ||
"builtins.open", mocker.mock_open(read_data="strain 1\t5\nstrain 2\t6\nstrain 3\t8\n") | ||
) | ||
|
||
class TestFilter: | ||
def test_read_vcf_compressed(self): | ||
seq_keep, all_seq = augur.filter.read_vcf( | ||
"tests/builds/tb/data/lee_2015.vcf.gz" | ||
) | ||
|
||
assert len(seq_keep) == 150 | ||
assert seq_keep[149] == "G22733" | ||
assert seq_keep == all_seq | ||
return parser.parse_args(shlex.split(args)) | ||
|
||
def test_read_vcf_uncompressed(self): | ||
seq_keep, all_seq = augur.filter.read_vcf("tests/builds/tb/data/lee_2015.vcf") | ||
|
||
assert len(seq_keep) == 150 | ||
assert seq_keep[149] == "G22733" | ||
assert seq_keep == all_seq | ||
def write_file(tmpdir, filename:str, content:str):
    """Write *content* to *filename* inside *tmpdir* and return the file's path as a string."""
    path = str(tmpdir / filename)
    with open(path, "w") as out:
        out.write(content)
    return path
|
||
def test_read_priority_scores_valid(self, mock_priorities_file_valid): | ||
# builtins.open is stubbed, but we need a valid file to satisfy the existence check | ||
priorities = augur.filter.read_priority_scores( | ||
"tests/builds/tb/data/lee_2015.vcf" | ||
) | ||
|
||
assert priorities == {"strain1": 5, "strain2": 6, "strain3": 8} | ||
assert priorities["strain1"] == 5 | ||
assert priorities["strain42"] == -np.inf, "Default priority is negative infinity for unlisted sequences" | ||
|
||
def test_read_priority_scores_malformed(self, mock_priorities_file_malformed): | ||
with pytest.raises(ValueError): | ||
# builtins.open is stubbed, but we need a valid file to satisfy the existence check | ||
augur.filter.read_priority_scores("tests/builds/tb/data/lee_2015.vcf") | ||
|
||
def test_read_priority_scores_valid_with_spaces_and_tabs(self, mock_priorities_file_valid_with_spaces_and_tabs): | ||
# builtins.open is stubbed, but we need a valid file to satisfy the existence check | ||
priorities = augur.filter.read_priority_scores( | ||
"tests/builds/tb/data/lee_2015.vcf" | ||
) | ||
|
||
assert priorities == {"strain 1": 5, "strain 2": 6, "strain 3": 8} | ||
|
||
def test_read_priority_scores_does_not_exist(self): | ||
with pytest.raises(FileNotFoundError): | ||
augur.filter.read_priority_scores("/does/not/exist.txt") | ||
|
||
def test_write_vcf_compressed_input(self, mock_run_shell_command): | ||
augur.filter.write_vcf( | ||
"tests/builds/tb/data/lee_2015.vcf.gz", "output_file.vcf.gz", [] | ||
) | ||
|
||
augur.filter.run_shell_command.assert_called_once_with( | ||
"vcftools --gzvcf tests/builds/tb/data/lee_2015.vcf.gz --recode --stdout | gzip -c > output_file.vcf.gz", | ||
raise_errors=True, | ||
) | ||
|
||
def test_write_vcf_uncompressed_input(self, mock_run_shell_command): | ||
augur.filter.write_vcf( | ||
"tests/builds/tb/data/lee_2015.vcf", "output_file.vcf.gz", [] | ||
) | ||
|
||
augur.filter.run_shell_command.assert_called_once_with( | ||
"vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout | gzip -c > output_file.vcf.gz", | ||
raise_errors=True, | ||
) | ||
|
||
def test_write_vcf_compressed_output(self, mock_run_shell_command): | ||
augur.filter.write_vcf( | ||
"tests/builds/tb/data/lee_2015.vcf", "output_file.vcf.gz", [] | ||
) | ||
|
||
augur.filter.run_shell_command.assert_called_once_with( | ||
"vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout | gzip -c > output_file.vcf.gz", | ||
raise_errors=True, | ||
) | ||
|
||
def test_write_vcf_uncompressed_output(self, mock_run_shell_command): | ||
augur.filter.write_vcf( | ||
"tests/builds/tb/data/lee_2015.vcf", "output_file.vcf", [] | ||
) | ||
|
||
augur.filter.run_shell_command.assert_called_once_with( | ||
"vcftools --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout > output_file.vcf", | ||
raise_errors=True, | ||
) | ||
|
||
def test_write_vcf_dropped_samples(self, mock_run_shell_command): | ||
augur.filter.write_vcf( | ||
"tests/builds/tb/data/lee_2015.vcf", "output_file.vcf", ["x", "y", "z"] | ||
) | ||
|
||
augur.filter.run_shell_command.assert_called_once_with( | ||
"vcftools --remove-indv x --remove-indv y --remove-indv z --vcf tests/builds/tb/data/lee_2015.vcf --recode --stdout > output_file.vcf", | ||
raise_errors=True, | ||
) | ||
def write_metadata(tmpdir, metadata):
    """Serialize *metadata* (an iterable of row tuples) as TSV and write it to metadata.tsv in *tmpdir*."""
    rows = ("\t".join(row) for row in metadata)
    return write_file(tmpdir, "metadata.tsv", "\n".join(rows))
|
||
def test_filter_on_query_good(self, tmpdir, sequences): | ||
"""Basic filter_on_query test""" | ||
meta_fn = write_metadata(tmpdir, (("strain","location","quality"), | ||
("SEQ_1","colorado","good"), | ||
("SEQ_2","colorado","bad"), | ||
("SEQ_3","nevada","good"))) | ||
metadata, columns = read_metadata(meta_fn, as_data_frame=True) | ||
filtered = augur.filter.filter_by_query(metadata, 'quality=="good"') | ||
assert sorted(filtered) == ["SEQ_1", "SEQ_3"] | ||
|
||
def test_filter_run_with_query(self, tmpdir, fasta_fn, argparser): | ||
"""Test that filter --query works as expected""" | ||
out_fn = str(tmpdir / "out.fasta") | ||
meta_fn = write_metadata(tmpdir, (("strain","location","quality"), | ||
("SEQ_1","colorado","good"), | ||
("SEQ_2","colorado","bad"), | ||
("SEQ_3","nevada","good"))) | ||
args = argparser('-s %s --metadata %s -o %s --query "location==\'colorado\'"' | ||
% (fasta_fn, meta_fn, out_fn)) | ||
augur.filter.run(args) | ||
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) | ||
assert list(output.keys()) == ["SEQ_1", "SEQ_2"] | ||
def get_filter_obj_run(args:argparse.Namespace):
    """Build a FilterSQLite against an in-memory database, apply *args*, run it, and return it.

    Intermediate tables are intentionally kept (cleanup=False) so tests can
    inspect their contents after the run.
    """
    filter_obj = FilterSQLite(':memory:')
    filter_obj.set_args(args)
    filter_obj.run(cleanup=False)
    return filter_obj
|
||
def test_filter_run_with_query_and_include(self, tmpdir, fasta_fn, argparser): | ||
"""Test that --include still works with filtering on query""" | ||
out_fn = str(tmpdir / "out.fasta") | ||
meta_fn = write_metadata(tmpdir, (("strain","location","quality"), | ||
("SEQ_1","colorado","good"), | ||
("SEQ_2","colorado","bad"), | ||
("SEQ_3","nevada","good"))) | ||
include_fn = str(tmpdir / "include") | ||
open(include_fn, "w").write("SEQ_3") | ||
args = argparser('-s %s --metadata %s -o %s --query "quality==\'good\' & location==\'colorado\'" --include %s' | ||
% (fasta_fn, meta_fn, out_fn, include_fn)) | ||
augur.filter.run(args) | ||
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) | ||
assert list(output.keys()) == ["SEQ_1", "SEQ_3"] | ||
|
||
def test_filter_run_with_query_and_include_where(self, tmpdir, fasta_fn, argparser): | ||
"""Test that --include_where still works with filtering on query""" | ||
out_fn = str(tmpdir / "out.fasta") | ||
meta_fn = write_metadata(tmpdir, (("strain","location","quality"), | ||
("SEQ_1","colorado","good"), | ||
("SEQ_2","colorado","bad"), | ||
("SEQ_3","nevada","good"))) | ||
args = argparser('-s %s --metadata %s -o %s --query "quality==\'good\' & location==\'colorado\'" --include-where "location=nevada"' | ||
% (fasta_fn, meta_fn, out_fn)) | ||
augur.filter.run(args) | ||
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) | ||
assert list(output.keys()) == ["SEQ_1", "SEQ_3"] | ||
def get_valid_args(data, tmpdir):
    """Write *data* as a metadata file and return a parsed argparse.Namespace using it.

    Only the required --metadata and an --output-strains path are supplied.
    """
    metadata_file = write_metadata(tmpdir, data)
    return parse_args(f'--metadata {metadata_file} --output-strains {tmpdir / "strains.txt"}')
|
||
def test_filter_run_min_date(self, tmpdir, fasta_fn, argparser): | ||
"""Test that filter --min-date is inclusive""" | ||
out_fn = str(tmpdir / "out.fasta") | ||
min_date = "2020-02-26" | ||
meta_fn = write_metadata(tmpdir, (("strain","date"), | ||
("SEQ_1","2020-02-XX"), | ||
("SEQ_2","2020-02-26"), | ||
("SEQ_3","2020-02-25"))) | ||
args = argparser('-s %s --metadata %s -o %s --min-date %s' | ||
% (fasta_fn, meta_fn, out_fn, min_date)) | ||
augur.filter.run(args) | ||
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) | ||
assert list(output.keys()) == ["SEQ_1", "SEQ_2"] | ||
|
||
def test_filter_run_max_date(self, tmpdir, fasta_fn, argparser): | ||
"""Test that filter --max-date is inclusive""" | ||
out_fn = str(tmpdir / "out.fasta") | ||
max_date = "2020-03-01" | ||
meta_fn = write_metadata(tmpdir, (("strain","date"), | ||
("SEQ_1","2020-03-XX"), | ||
("SEQ_2","2020-03-01"), | ||
("SEQ_3","2020-03-02"))) | ||
args = argparser('-s %s --metadata %s -o %s --max-date %s' | ||
% (fasta_fn, meta_fn, out_fn, max_date)) | ||
augur.filter.run(args) | ||
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) | ||
assert list(output.keys()) == ["SEQ_1", "SEQ_2"] | ||
def query_fetchall(filter_obj:FilterSQLite, query:str):
    """Run *query* on the filter object's cursor and return every result row."""
    cursor = filter_obj.cur
    cursor.execute(query)
    return cursor.fetchall()
|
||
def test_filter_incomplete_year(self, tmpdir, fasta_fn, argparser): | ||
"""Test that 2020 is evaluated as 2020-XX-XX""" | ||
out_fn = str(tmpdir / "out.fasta") | ||
min_date = "2020-02-01" | ||
meta_fn = write_metadata(tmpdir, (("strain","date"), | ||
("SEQ_1","2020.0"), | ||
("SEQ_2","2020"), | ||
("SEQ_3","2020-XX-XX"))) | ||
args = argparser('-s %s --metadata %s -o %s --min-date %s' | ||
% (fasta_fn, meta_fn, out_fn, min_date)) | ||
augur.filter.run(args) | ||
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) | ||
assert list(output.keys()) == ["SEQ_2", "SEQ_3"] | ||
|
||
def test_filter_date_formats(self, tmpdir, fasta_fn, argparser): | ||
"""Test that 2020.0, 2020, and 2020-XX-XX all pass --min-date 2019""" | ||
out_fn = str(tmpdir / "out.fasta") | ||
min_date = "2019" | ||
meta_fn = write_metadata(tmpdir, (("strain","date"), | ||
("SEQ_1","2020.0"), | ||
("SEQ_2","2020"), | ||
("SEQ_3","2020-XX-XX"))) | ||
args = argparser('-s %s --metadata %s -o %s --min-date %s' | ||
% (fasta_fn, meta_fn, out_fn, min_date)) | ||
augur.filter.run(args) | ||
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) | ||
assert list(output.keys()) == ["SEQ_1", "SEQ_2", "SEQ_3"] | ||
def query_fetchall_dict(filter_obj:FilterSQLite, query:str):
    """Run *query* and return the rows as a list of column-name → value dicts."""
    connection = filter_obj.connection
    # sqlite3.Row exposes columns by name, letting each row convert to a dict
    connection.row_factory = sqlite3.Row
    dict_cursor = connection.cursor()
    dict_cursor.execute(query)
    return [dict(row) for row in dict_cursor.fetchall()]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import pytest | ||
|
||
from augur.filter_support.db.sqlite import ( | ||
METADATA_TABLE_NAME, | ||
PRIORITIES_TABLE_NAME, | ||
) | ||
|
||
from test_filter import write_file | ||
from tests.test_filter import get_filter_obj_run, get_valid_args, query_fetchall | ||
|
||
|
||
def get_filter_obj_with_priority_loaded(tmpdir, content:str):
    """Write *content* to a priorities file and return a filter object run with it set."""
    priority_path = write_file(tmpdir, "priorities.txt", content)
    # metadata is a required arg but we don't need it
    minimal_metadata = [("strain","location","quality"),
                        ("SEQ_1","colorado","good")]
    args = get_valid_args(minimal_metadata, tmpdir)
    args.priority = priority_path
    return get_filter_obj_run(args)
|
||
|
||
class TestDataLoading:
    def test_load_metadata(self, tmpdir):
        """Metadata rows written to disk end up in the metadata table."""
        metadata = [("strain","location","quality"),
                    ("SEQ_1","colorado","good"),
                    ("SEQ_2","colorado","bad"),
                    ("SEQ_3","nevada","good")]
        filter_obj = get_filter_obj_run(get_valid_args(metadata, tmpdir))
        rows = query_fetchall(filter_obj, f"SELECT * FROM {METADATA_TABLE_NAME}")
        # drop the leading index column before comparing against the input rows
        assert [row[1:] for row in rows] == metadata[1:]

    def test_load_priority_scores_valid(self, tmpdir):
        """A tab-delimited priority file loads into the priorities table as floats."""
        filter_obj = get_filter_obj_with_priority_loaded(
            tmpdir, "strain1\t5\nstrain2\t6\nstrain3\t8\n")
        filter_obj.db_load_priorities_table()
        rows = query_fetchall(filter_obj, f"SELECT * FROM {PRIORITIES_TABLE_NAME}")
        assert rows == [(0, "strain1", 5.0), (1, "strain2", 6.0), (2, "strain3", 8.0)]

    @pytest.mark.skip(reason="this isn't trivial with SQLite's flexible typing rules")
    def test_load_priority_scores_malformed(self, tmpdir):
        """A non-float value in the priority column raises a ValueError."""
        filter_obj = get_filter_obj_with_priority_loaded(tmpdir, "strain1 X\n")
        with pytest.raises(ValueError) as e_info:
            filter_obj.db_load_priorities_table()
        assert str(e_info.value) == f"Failed to parse priority file {filter_obj.args.priority}."

    def test_load_priority_scores_valid_with_spaces_and_tabs(self, tmpdir):
        """Strain names containing spaces survive a priority-file load."""
        filter_obj = get_filter_obj_with_priority_loaded(
            tmpdir, "strain 1\t5\nstrain 2\t6\nstrain 3\t8\n")
        filter_obj.db_load_priorities_table()
        rows = query_fetchall(filter_obj, f"SELECT * FROM {PRIORITIES_TABLE_NAME}")
        assert rows == [(0, "strain 1", 5.0), (1, "strain 2", 6.0), (2, "strain 3", 8.0)]

    def test_load_priority_scores_does_not_exist(self, tmpdir):
        """Loading a non-existent priority score file raises a FileNotFoundError."""
        missing_priorities_fn = str(tmpdir / "does/not/exist.txt")
        # metadata is a required arg but we don't need it
        metadata = [("strain","location","quality"),
                    ("SEQ_1","colorado","good")]
        args = get_valid_args(metadata, tmpdir)
        args.priority = missing_priorities_fn
        filter_obj = get_filter_obj_run(args)
        with pytest.raises(FileNotFoundError):
            filter_obj.db_load_priorities_table()
Oops, something went wrong.