Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Augur Mask: Add additional options from NCOV mask-alignment.py script #512

Merged
merged 17 commits into from
Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions augur/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import MutableSeq

from .utils import run_shell_command, shquote, open_file, is_vcf

Expand Down Expand Up @@ -88,7 +89,7 @@ def mask_vcf(mask_sites, in_file, out_file, cleanup=True):
except OSError:
pass

def mask_fasta(mask_sites, in_file, out_file):
def mask_fasta(mask_sites, in_file, out_file, mask_from_beginning=0, mask_from_end=0):
"""Mask the provided site list from a FASTA file and write to a new file.

Masked sites are overwritten as "N"s.
Expand All @@ -101,6 +102,10 @@ def mask_fasta(mask_sites, in_file, out_file):
The path to the FASTA file you wish to mask.
out_file: str
The path to write the resulting FASTA to
mask_from_beginning: int
Number of sites to mask from the beginning of each sequence (default 0)
mask_from_end: int
Number of sites to mask from the end of each sequence (default 0)
"""
# Load alignment as FASTA generator to prevent loading the whole alignment
# into memory.
Expand All @@ -111,8 +116,15 @@ def mask_fasta(mask_sites, in_file, out_file):
with open_file(out_file, "w") as oh:
for record in alignment:
# Convert to a mutable sequence to enable masking with Ns.
sequence = record.seq.tomutable()
sequence_length = len(sequence)
sequence_length = len(record.seq)
beginning, end = mask_from_beginning, mask_from_end
if beginning + end > sequence_length:
beginning, end = sequence_length, 0
sequence = MutableSeq(
"N" * beginning +
str(record.seq)[beginning:-end or None] +
"N" * end
)
# Replace all excluded sites with Ns.
for site in mask_sites:
if site < sequence_length:
Expand All @@ -122,7 +134,10 @@ def mask_fasta(mask_sites, in_file, out_file):

def register_arguments(parser):
parser.add_argument('--sequences', '-s', required=True, help="sequences in VCF or FASTA format")
parser.add_argument('--mask', required=True, help="locations to be masked in BED file format")
parser.add_argument('--mask', dest="mask_file", required=False, help="locations to be masked in BED file format")
parser.add_argument('--mask-from-beginning', type=int, help="FASTA Only: Number of sites to mask from beginning")
parser.add_argument('--mask-from-end', type=int, help="FASTA Only: Number of sites to mask from end")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two arguments need to have default values of 0 or they will get passed as None to the downstream math in mask_fasta and produce this exception:

Removing masked sites from FASTA file.
Traceback (most recent call last):
  File "./bin/augur", line 24, in <module>
    exit( main() )
  File "./augur/__main__.py", line 10, in main
    return augur.run( argv[1:] )
  File "./augur/__init__.py", line 74, in run
    return args.__command__.run(args)
  File "./augur/mask.py", line 197, in run
    mask_from_end=args.mask_from_end)
  File "./augur/mask.py", line 121, in mask_fasta
    if beginning + end > sequence_length:
TypeError: unsupported operand type(s) for +: 'NoneType' and 'NoneType'

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Crap, I missed that. Per your comment below, I've added a couple of E2E tests that exercise the full code path. I should've done that in the first place; I've been burned by the "unit tests but no integration tests" bug before.

parser.add_argument("--mask-sites", nargs='+', type = int, help="list of sites to mask")
parser.add_argument('--output', '-o', help="output file")
parser.add_argument('--no-cleanup', dest="cleanup", action="store_false",
help="Leave intermediate files around. May be useful for debugging")
Expand All @@ -141,19 +156,28 @@ def run(args):
# Check files exist and are not empty
if not os.path.isfile(args.sequences):
print("ERROR: File {} does not exist!".format(args.sequences))
return 1
if not os.path.isfile(args.mask):
print("ERROR: File {} does not exist!".format(args.mask))
return 1
sys.exit(1)
if os.path.getsize(args.sequences) == 0:
print("ERROR: {} is empty. Please check how this file was produced. "
"Did an error occur in an earlier step?".format(args.sequences))
return 1
if os.path.getsize(args.mask) == 0:
print("ERROR: {} is an empty file.".format(args.mask))
return 1
sys.exit(1)
if args.mask_file:
if not os.path.isfile(args.mask_file):
print("ERROR: File {} does not exist!".format(args.mask_file))
sys.exit(1)
if os.path.getsize(args.mask_file) == 0:
print("ERROR: {} is an empty file.".format(args.mask_file))
sys.exit(1)
if not any((args.mask_file, args.mask_from_beginning, args.mask_from_end, args.mask_sites)):
print("No masking sites provided. Must include one of --mask, --mask-from-beginning, --mask-from-end, or --mask-sites")
sys.exit(1)

mask_sites = read_bed_file(args.mask)
mask_sites = set()
if args.mask_sites:
mask_sites.update(args.mask_sites)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See main review, but this is where mask sites need to be converted from 1-based to 0-based coordinates.

if args.mask_file:
mask_sites.update(read_bed_file(args.mask_file))
mask_sites = sorted(mask_sites)

# For both FASTA and VCF masking, we need a proper separate output file
if args.output is not None:
Expand All @@ -163,9 +187,14 @@ def run(args):
"masked_" + os.path.basename(args.sequences))

if is_vcf(args.sequences):
if args.mask_from_beginning or args.mask_from_end:
print("Cannot use --mask-from-beginning or --mask-from-end with VCF files!")
sys.exit(1)
mask_vcf(mask_sites, args.sequences, out_file, args.cleanup)
else:
mask_fasta(mask_sites, args.sequences, out_file)
mask_fasta(mask_sites, args.sequences, out_file,
mask_from_beginning=args.mask_from_beginning,
mask_from_end=args.mask_from_end)

if args.output is None:
copyfile(out_file, args.sequences)
Expand Down
120 changes: 105 additions & 15 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def bed_file(tmpdir):
fh.write(TEST_BED)
return bed_file

@pytest.fixture
def out_file(tmpdir):
out_file = str(tmpdir / "out")
open(out_file, "w").close()
return out_file

@pytest.fixture
def mp_context(monkeypatch):
#Have found explicit monkeypatch context-ing prevents stupid bugs
Expand Down Expand Up @@ -136,9 +142,8 @@ def test_shell(call, raise_errors=True):
in_file = vcf_file + ".gz"
mask.mask_vcf([1,5], in_file, in_file)

def test_mask_vcf_removes_matching_sites(self, tmpdir, vcf_file):
def test_mask_vcf_removes_matching_sites(self, vcf_file, out_file):
"""mask_vcf should remove the given sites from the VCF file"""
out_file = str(tmpdir / "output.vcf")
mask.mask_vcf([5,6], vcf_file, out_file)
with open(out_file) as after, open(vcf_file) as before:
assert len(after.readlines()) == len(before.readlines()) - 1, "Too many lines removed!"
Expand All @@ -156,9 +161,8 @@ def test_mask_vcf_cleanup_flag(self, vcf_file, mp_context):
mask.mask_vcf([], vcf_file, "", cleanup=False)
assert os.path.isfile(tmp_mask_file), "Temporary mask cleaned up as expected"

def test_mask_fasta_normal_case(self, tmpdir, fasta_file, sequences):
def test_mask_fasta_normal_case(self, fasta_file, out_file, sequences):
"""mask_fasta normal case - all sites in sequences"""
out_file = str(tmpdir / "output.fasta")
mask_sites = [5,10]
mask.mask_fasta([5,10], fasta_file, out_file)
output = SeqIO.parse(out_file, "fasta")
Expand All @@ -170,9 +174,8 @@ def test_mask_fasta_normal_case(self, tmpdir, fasta_file, sequences):
else:
assert site == "N", "Not all sites modified correctly!"

def test_mask_fasta_out_of_index(self, tmpdir, fasta_file, sequences):
def test_mask_fasta_out_of_index(self, out_file, fasta_file, sequences):
"""mask_fasta provided a list of indexes past the length of the sequences"""
out_file = str(tmpdir / "output.fasta")
max_length = max(len(record.seq) for record in sequences.values())
mask.mask_fasta([5, max_length, max_length+5], fasta_file, out_file)
output = SeqIO.parse(out_file, "fasta")
Expand All @@ -183,6 +186,62 @@ def test_mask_fasta_out_of_index(self, tmpdir, fasta_file, sequences):
if idx != 5:
assert site == original[idx], "Incorrect sites modified!"

def test_mask_fasta_from_beginning(self, out_file, fasta_file, sequences):
mask.mask_fasta([], fasta_file, out_file, mask_from_beginning=3)
output = SeqIO.parse(out_file, "fasta")
for seq in output:
original = sequences[seq.id]
assert seq.seq[:3] == "NNN"
assert seq.seq[3:] == original.seq[3:]

def test_mask_fasta_from_end(self, out_file, fasta_file, sequences):
mask.mask_fasta([], fasta_file, out_file, mask_from_end=3)
output = SeqIO.parse(out_file, "fasta")
for seq in output:
original = sequences[seq.id]
assert seq.seq[-3:] == "NNN"
assert seq.seq[:-3] == original.seq[:-3]

def test_mask_fasta_from_beginning_and_end(self, out_file, fasta_file, sequences):
mask.mask_fasta([], fasta_file, out_file, mask_from_beginning=2, mask_from_end=3)
output = SeqIO.parse(out_file, "fasta")
for seq in output:
original = sequences[seq.id]
assert seq.seq[:2] == "NN"
assert seq.seq[-3:] == "NNN"
assert seq.seq[2:-3] == original.seq[2:-3]

@pytest.mark.parametrize("beginning,end", ((1000,0), (0,1000),(1000,1000)))
def test_mask_fasta_from_beginning_and_end_too_long(self, fasta_file, out_file, beginning, end):
mask.mask_fasta([], fasta_file, out_file, mask_from_beginning=beginning, mask_from_end=end)
output = SeqIO.parse(out_file, "fasta")
for record in output:
assert record.seq == "N" * len(record.seq)

def test_run_handle_missing_sequence_file(self, vcf_file, argparser):
os.remove(vcf_file)
args = argparser("-s %s" % vcf_file)
with pytest.raises(SystemExit):
mask.run(args)

def test_run_handle_empty_sequence_file(self, vcf_file, argparser):
open(vcf_file,"w").close()
args = argparser("-s %s --mask-sites 1" % vcf_file)
with pytest.raises(SystemExit):
mask.run(args)

def test_run_handle_missing_mask_file(self, vcf_file, bed_file, argparser):
os.remove(bed_file)
args = argparser("-s %s --mask %s" % (vcf_file, bed_file))
with pytest.raises(SystemExit):
mask.run(args)

def test_run_handle_empty_mask_file(self, vcf_file, bed_file, argparser):
open(bed_file, "w").close()
args = argparser("-s %s --mask %s" % (vcf_file, bed_file))
with pytest.raises(SystemExit):
mask.run(args)

def test_run_recognize_vcf(self, bed_file, vcf_file, argparser, mp_context):
"""Ensure we're handling vcf files correctly"""
args = argparser("--mask=%s -s %s --no-cleanup" % (bed_file, vcf_file))
Expand All @@ -206,7 +265,7 @@ def fail(*args, **kwargs):
def test_run_handle_missing_outfile(self, bed_file, fasta_file, argparser, mp_context):
args = argparser("--mask=%s -s %s" % (bed_file, fasta_file))
expected_outfile = os.path.join(os.path.dirname(fasta_file), "masked_" + os.path.basename(fasta_file))
def check_outfile(mask_sites, in_file, out_file):
def check_outfile(mask_sites, in_file, out_file, **kwargs):
assert out_file == expected_outfile
with open(out_file, "w") as fh:
fh.write("test_string")
Expand All @@ -215,7 +274,7 @@ def check_outfile(mask_sites, in_file, out_file):
with open(fasta_file) as fh:
assert fh.read() == "test_string"

def test_run_respect_no_cleanup(self, bed_file, tmpdir, vcf_file, argparser, mp_context):
def test_run_respect_no_cleanup(self, bed_file, vcf_file, argparser, mp_context):
out_file = os.path.join(os.path.dirname(vcf_file), "masked_" + os.path.basename(vcf_file))
def make_outfile(mask_sites, in_file, out_file, cleanup=True):
assert cleanup == False
Expand All @@ -225,15 +284,46 @@ def make_outfile(mask_sites, in_file, out_file, cleanup=True):
mask.run(args)
assert os.path.exists(out_file), "Output file incorrectly deleted"

def test_run_normal_case(self, bed_file, vcf_file, tmpdir, argparser, mp_context):
test_outfile = str(tmpdir / "out")
def check_args(mask_sites, in_file, out_file, cleanup):
def test_run_normal_case(self, bed_file, vcf_file, out_file, argparser, mp_context):
def check_args(mask_sites, in_file, _out_file, cleanup):
assert mask_sites == TEST_BED_SEQUENCE, "Wrong mask sites provided"
assert in_file == vcf_file, "Incorrect input file provided"
assert out_file == test_outfile, "Incorrect output file provided"
assert _out_file == out_file, "Incorrect output file provided"
assert cleanup is True, "Cleanup erroneously passed in as False"
open(out_file, "w").close() # want to test we don't delete output.
mp_context.setattr(mask, "mask_vcf", check_args)
args = argparser("--mask=%s --sequences=%s --output=%s" %(bed_file, vcf_file, test_outfile))
args = argparser("--mask=%s --sequences=%s --output=%s" %(bed_file, vcf_file, out_file))
mask.run(args)
assert os.path.exists(out_file), "Output file incorrectly deleted"

def test_run_with_mask_sites(self, vcf_file, out_file, argparser, mp_context):
args = argparser("--mask-sites 2 8 -s %s -o %s" % (vcf_file, out_file))
def check_mask_sites(mask_sites, *args, **kwargs):
assert mask_sites == [2,8]
mp_context.setattr(mask, "mask_vcf", check_mask_sites)
mask.run(args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test would be more helpful if it ran mask_vcf and confirmed that the requested sites were no longer present in the output file. I actually love the way you've set up the argparser fixture here such that we can essentially run end-to-end tests from the unit testing framework. I know it is a little more computationally expensive, but running the whole program this way can catch more errors.

We also need a corresponding test for masking specific sites with a FASTA input. This test will fail right now because mask_sites values are not converted to 0-based coordinates (and because of the mask from beginning/end None issue mentioned earlier).

def test_run_with_mask_sites_for_fasta(self, fasta_file, out_file, argparser):
    """Run the mask command end to end on a FASTA file with ``--mask-sites 1``.

    The user-facing site number is 1-based, so requesting site 1 is expected
    to overwrite index 0 of every sequence record in the output with "N".
    """
    command_line = "--mask-sites 1 -s %s -o %s" % (fasta_file, out_file)
    parsed_args = argparser(command_line)
    mask.run(parsed_args)
    for masked_record in SeqIO.parse(out_file, "fasta"):
        # User-supplied site 1 corresponds to index 0 of the record's sequence.
        first_base = masked_record.seq[0]
        assert first_base == "N"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually running fully through the functions is a good call. I've added E2E tests for handling both VCF and FASTA masking. The test files are all small enough that it's not really that expensive to do the full test, and the test files are created & written specifically for each test (that's what the pytest tmpdir fixture is giving us), so there shouldn't be any race conditions here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is so cool! Thank you!


def test_run_with_mask_sites_and_mask_file(self, vcf_file, out_file, bed_file, argparser, mp_context):
args = argparser("--mask-sites 20 21 --mask %s -s %s -o %s" % (bed_file, vcf_file, out_file))
def check_mask_sites(mask_sites, *args, **kwargs):
assert mask_sites == sorted(set(TEST_BED_SEQUENCE + [20,21]))
mp_context.setattr(mask, "mask_vcf", check_mask_sites)
mask.run(args)

def test_run_requires_some_masking(self, vcf_file, argparser):
args = argparser("-s %s" % vcf_file)
with pytest.raises(SystemExit) as err:
mask.run(args)

@pytest.mark.parametrize("op", ("beginning", "end"))
def test_run_vcf_cannot_mask_beginning_or_end(self, vcf_file, argparser, op):
args = argparser("-s %s --mask-from-%s 2" % (vcf_file, op))
with pytest.raises(SystemExit) as err:
mask.run(args)

def test_run_fasta_mask_from_beginning_or_end(self, fasta_file, out_file, argparser, mp_context):
args = argparser("-s %s -o %s --mask-from-beginning 2 --mask-from-end 3" % (fasta_file, out_file))
def check_mask_from(*args, mask_from_beginning=0, mask_from_end=0):
assert mask_from_beginning == 2
assert mask_from_end == 3
mp_context.setattr(mask, "mask_fasta", check_mask_from)
mask.run(args)
assert os.path.exists(test_outfile), "Output file incorrectly deleted"