diff --git a/augur/mask.py b/augur/mask.py index 40da76632..13f40ad9e 100644 --- a/augur/mask.py +++ b/augur/mask.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd from Bio import SeqIO +from Bio.Seq import MutableSeq from .utils import run_shell_command, shquote, open_file, is_vcf @@ -32,7 +33,7 @@ def read_bed_file(mask_file): bed = pd.read_csv(mask_file, sep='\t', header=None, usecols=[1,2]) for idx, row in bed.iterrows(): try: - sites_to_mask.extend(list(range(int(row[1]), int(row[2])+1))) + sites_to_mask.extend(range(int(row[1]), int(row[2]))) except ValueError as err: # Skip unparseable lines, including header lines. print("Could not read line %d of BED file %s: %s. Continuing." % (idx, mask_file, err)) @@ -66,7 +67,8 @@ def mask_vcf(mask_sites, in_file, out_file, cleanup=True): "Please check the file is valid VCF format.") sys.exit(1) - exclude = [chrom_name + "\t" + str(pos) for pos in mask_sites] + # mask_sites is zero-indexed, VCFTools expects 1-indexed. + exclude = [chrom_name + "\t" + str(pos + 1) for pos in mask_sites] temp_mask_file = in_file + "_maskTemp" with open_file(temp_mask_file, 'w') as fh: fh.write("\n".join(exclude)) @@ -88,7 +90,7 @@ def mask_vcf(mask_sites, in_file, out_file, cleanup=True): except OSError: pass -def mask_fasta(mask_sites, in_file, out_file): +def mask_fasta(mask_sites, in_file, out_file, mask_from_beginning=0, mask_from_end=0): """Mask the provided site list from a FASTA file and write to a new file. Masked sites are overwritten as "N"s. @@ -101,6 +103,10 @@ def mask_fasta(mask_sites, in_file, out_file): The path to the FASTA file you wish to mask. out_file: str The path to write the resulting FASTA to + mask_from_beginning: int + Number of sites to mask from the beginning of each sequence (default 0) + mask_from_end: int + Number of sites to mask from the end of each sequence (default 0) """ # Load alignment as FASTA generator to prevent loading the whole alignment # into memory. @@ -111,8 +117,15 @@ def mask_fasta(mask_sites, in_file, out_file): with open_file(out_file, "w") as oh: for record in alignment: # Convert to a mutable sequence to enable masking with Ns. - sequence = record.seq.tomutable() - sequence_length = len(sequence) + sequence_length = len(record.seq) + beginning, end = mask_from_beginning, mask_from_end + if beginning + end > sequence_length: + beginning, end = sequence_length, 0 + sequence = MutableSeq( + "N" * beginning + + str(record.seq)[beginning:-end or None] + + "N" * end + ) # Replace all excluded sites with Ns. for site in mask_sites: if site < sequence_length: @@ -122,7 +135,10 @@ def mask_fasta(mask_sites, in_file, out_file): def register_arguments(parser): parser.add_argument('--sequences', '-s', required=True, help="sequences in VCF or FASTA format") - parser.add_argument('--mask', required=True, help="locations to be masked in BED file format") + parser.add_argument('--mask', dest="mask_file", required=False, help="locations to be masked in BED file format") + parser.add_argument('--mask-from-beginning', type=int, default=0, help="FASTA Only: Number of sites to mask from beginning") + parser.add_argument('--mask-from-end', type=int, default=0, help="FASTA Only: Number of sites to mask from end") + parser.add_argument("--mask-sites", nargs='+', type = int, help="1-indexed list of sites to mask") parser.add_argument('--output', '-o', help="output file") parser.add_argument('--no-cleanup', dest="cleanup", action="store_false", help="Leave intermediate files around. May be useful for debugging") @@ -141,19 +157,29 @@ def run(args): # Check files exist and are not empty if not os.path.isfile(args.sequences): print("ERROR: File {} does not exist!".format(args.sequences)) - return 1 - if not os.path.isfile(args.mask): - print("ERROR: File {} does not exist!".format(args.mask)) - return 1 + sys.exit(1) if os.path.getsize(args.sequences) == 0: print("ERROR: {} is empty. Please check how this file was produced. " "Did an error occur in an earlier step?".format(args.sequences)) - return 1 - if os.path.getsize(args.mask) == 0: - print("ERROR: {} is an empty file.".format(args.mask)) - return 1 + sys.exit(1) + if args.mask_file: + if not os.path.isfile(args.mask_file): + print("ERROR: File {} does not exist!".format(args.mask_file)) + sys.exit(1) + if os.path.getsize(args.mask_file) == 0: + print("ERROR: {} is an empty file.".format(args.mask_file)) + sys.exit(1) + if not any((args.mask_file, args.mask_from_beginning, args.mask_from_end, args.mask_sites)): + print("No masking sites provided. Must include one of --mask, --mask-from-beginning, --mask-from-end, or --mask-sites") + sys.exit(1) - mask_sites = read_bed_file(args.mask) + mask_sites = set() + if args.mask_sites: + # Mask sites passed in as 1-indexed + mask_sites.update(site - 1 for site in args.mask_sites) + if args.mask_file: + mask_sites.update(read_bed_file(args.mask_file)) + mask_sites = sorted(mask_sites) # For both FASTA and VCF masking, we need a proper separate output file if args.output is not None: @@ -163,9 +189,14 @@ def run(args): "masked_" + os.path.basename(args.sequences)) if is_vcf(args.sequences): + if args.mask_from_beginning or args.mask_from_end: + print("Cannot use --mask-from-beginning or --mask-from-end with VCF files!") + sys.exit(1) mask_vcf(mask_sites, args.sequences, out_file, args.cleanup) else: - mask_fasta(mask_sites, args.sequences, out_file) + mask_fasta(mask_sites, args.sequences, out_file, + mask_from_beginning=args.mask_from_beginning, + mask_from_end=args.mask_from_end) if args.output is None: copyfile(out_file, args.sequences) diff --git a/tests/test_mask.py b/tests/test_mask.py index 8ca7239bf..7fbd7cffb 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -36,6 +36,7 @@ def fasta_file(tmpdir, sequences): SEQ 3 . C T . . SEQ 5 . C T . . SEQ 8 . A C . . +SEQ 13 . A C . . """ @pytest.fixture @@ -45,11 +46,11 @@ def vcf_file(tmpdir): fh.write(TEST_VCF) return vcf_file -TEST_BED_SEQUENCE = [1,2,4,6,7,8,9,10] +TEST_BED_SEQUENCE = [1,4,6,7,8,9] # IF YOU UPDATE ONE OF THESE, UPDATE THE OTHER. TEST_BED="""\ SEQ1 1 2 IG18_Rv0018c-Rv0019c -SEQ1 4 4 IG18_Rv0018c-Rv0019c +SEQ1 4 5 IG18_Rv0018c-Rv0019c SEQ1 6 8 IG18_Rv0018c-Rv0019c SEQ1 7 10 IG18_Rv0018c-Rv0019c """ @@ -61,6 +62,12 @@ def bed_file(tmpdir): fh.write(TEST_BED) return bed_file +@pytest.fixture +def out_file(tmpdir): + out_file = str(tmpdir / "out") + open(out_file, "w").close() + return out_file + @pytest.fixture def mp_context(monkeypatch): #Have found explicit monkeypatch context-ing prevents stupid bugs @@ -94,7 +101,7 @@ def test_read_bed_file_with_header(self, bed_file): """read_bed_file should ignore header rows if they exist""" with open(bed_file, "w") as fh: fh.write("CHROM\tSTART\tEND\n") - fh.write("SEQ\t5\t6") + fh.write("SEQ\t5\t7") assert mask.read_bed_file(bed_file) == [5,6] def test_read_bed_file(self, bed_file): @@ -112,7 +119,7 @@ def test_mask_vcf_bails_on_no_chrom(self, tmpdir): mask.mask_vcf([], bad_vcf, "") def test_mask_vcf_creates_maskfile(self, vcf_file, mp_context): - """mask_vcf should create and use a mask file from the given list of sites""" + """mask_vcf should create a 1-indexed mask file from the 0-indexed list of sites""" mask_file = vcf_file + "_maskTemp" def shell_has_maskfile(call, **kwargs): assert mask_file in call @@ -121,7 +128,7 @@ def shell_has_maskfile(call, **kwargs): mask.mask_vcf([1,5], vcf_file, vcf_file, cleanup=False) assert os.path.isfile(mask_file), "Mask file was not written!" with open(mask_file) as fh: - assert fh.read() == "SEQ 1\nSEQ 5", "Incorrect mask file written!" + assert fh.read() == "SEQ 2\nSEQ 6", "Incorrect mask file written!" def test_mask_vcf_handles_gz(self, vcf_file, mp_context): """mask_vcf should recognize when the in or out files are .gz and call out accordingly""" @@ -136,10 +143,9 @@ def test_shell(call, raise_errors=True): in_file = vcf_file + ".gz" mask.mask_vcf([1,5], in_file, in_file) - def test_mask_vcf_removes_matching_sites(self, tmpdir, vcf_file): + def test_mask_vcf_removes_matching_sites(self, vcf_file, out_file): """mask_vcf should remove the given sites from the VCF file""" - out_file = str(tmpdir / "output.vcf") - mask.mask_vcf([5,6], vcf_file, out_file) + mask.mask_vcf([4,5], vcf_file, out_file) with open(out_file) as after, open(vcf_file) as before: assert len(after.readlines()) == len(before.readlines()) - 1, "Too many lines removed!" assert "SEQ\t5" not in after.read(), "Correct sites not removed!" @@ -156,9 +162,8 @@ def test_mask_vcf_cleanup_flag(self, vcf_file, mp_context): mask.mask_vcf([], vcf_file, "", cleanup=False) assert os.path.isfile(tmp_mask_file), "Temporary mask cleaned up as expected" - def test_mask_fasta_normal_case(self, tmpdir, fasta_file, sequences): + def test_mask_fasta_normal_case(self, fasta_file, out_file, sequences): """mask_fasta normal case - all sites in sequences""" - out_file = str(tmpdir / "output.fasta") mask_sites = [5,10] mask.mask_fasta([5,10], fasta_file, out_file) output = SeqIO.parse(out_file, "fasta") @@ -170,9 +175,8 @@ def test_mask_fasta_normal_case(self, tmpdir, fasta_file, sequences): else: assert site == "N", "Not all sites modified correctly!" - def test_mask_fasta_out_of_index(self, tmpdir, fasta_file, sequences): + def test_mask_fasta_out_of_index(self, out_file, fasta_file, sequences): """mask_fasta provided a list of indexes past the length of the sequences""" - out_file = str(tmpdir / "output.fasta") max_length = max(len(record.seq) for record in sequences.values()) mask.mask_fasta([5, max_length, max_length+5], fasta_file, out_file) output = SeqIO.parse(out_file, "fasta") @@ -183,6 +187,62 @@ def test_mask_fasta_out_of_index(self, tmpdir, fasta_file, sequences): if idx != 5: assert site == original[idx], "Incorrect sites modified!" + def test_mask_fasta_from_beginning(self, out_file, fasta_file, sequences): + mask.mask_fasta([], fasta_file, out_file, mask_from_beginning=3) + output = SeqIO.parse(out_file, "fasta") + for seq in output: + original = sequences[seq.id] + assert seq.seq[:3] == "NNN" + assert seq.seq[3:] == original.seq[3:] + + def test_mask_fasta_from_end(self, out_file, fasta_file, sequences): + mask.mask_fasta([], fasta_file, out_file, mask_from_end=3) + output = SeqIO.parse(out_file, "fasta") + for seq in output: + original = sequences[seq.id] + assert seq.seq[-3:] == "NNN" + assert seq.seq[:-3] == original.seq[:-3] + + def test_mask_fasta_from_beginning_and_end(self, out_file, fasta_file, sequences): + mask.mask_fasta([], fasta_file, out_file, mask_from_beginning=2, mask_from_end=3) + output = SeqIO.parse(out_file, "fasta") + for seq in output: + original = sequences[seq.id] + assert seq.seq[:2] == "NN" + assert seq.seq[-3:] == "NNN" + assert seq.seq[2:-3] == original.seq[2:-3] + + @pytest.mark.parametrize("beginning,end", ((1000,0), (0,1000),(1000,1000))) + def test_mask_fasta_from_beginning_and_end_too_long(self, fasta_file, out_file, beginning, end): + mask.mask_fasta([], fasta_file, out_file, mask_from_beginning=beginning, mask_from_end=end) + output = SeqIO.parse(out_file, "fasta") + for record in output: + assert record.seq == "N" * len(record.seq) + + def test_run_handle_missing_sequence_file(self, vcf_file, argparser): + os.remove(vcf_file) + args = argparser("-s %s" % vcf_file) + with pytest.raises(SystemExit): + mask.run(args) + + def test_run_handle_empty_sequence_file(self, vcf_file, argparser): + open(vcf_file,"w").close() + args = argparser("-s %s --mask-sites 1" % vcf_file) + with pytest.raises(SystemExit): + mask.run(args) + + def test_run_handle_missing_mask_file(self, vcf_file, bed_file, argparser): + os.remove(bed_file) + args = argparser("-s %s --mask %s" % (vcf_file, bed_file)) + with pytest.raises(SystemExit): + mask.run(args) + + def test_run_handle_empty_mask_file(self, vcf_file, bed_file, argparser): + open(bed_file, "w").close() + args = argparser("-s %s --mask %s" % (vcf_file, bed_file)) + with pytest.raises(SystemExit): + mask.run(args) + def test_run_recognize_vcf(self, bed_file, vcf_file, argparser, mp_context): """Ensure we're handling vcf files correctly""" args = argparser("--mask=%s -s %s --no-cleanup" % (bed_file, vcf_file)) @@ -206,7 +266,7 @@ def fail(*args, **kwargs): def test_run_handle_missing_outfile(self, bed_file, fasta_file, argparser, mp_context): args = argparser("--mask=%s -s %s" % (bed_file, fasta_file)) expected_outfile = os.path.join(os.path.dirname(fasta_file), "masked_" + os.path.basename(fasta_file)) - def check_outfile(mask_sites, in_file, out_file): + def check_outfile(mask_sites, in_file, out_file, **kwargs): assert out_file == expected_outfile with open(out_file, "w") as fh: fh.write("test_string") @@ -215,7 +275,7 @@ def check_outfile(mask_sites, in_file, out_file): with open(fasta_file) as fh: assert fh.read() == "test_string" - def test_run_respect_no_cleanup(self, bed_file, tmpdir, vcf_file, argparser, mp_context): + def test_run_respect_no_cleanup(self, bed_file, vcf_file, argparser, mp_context): out_file = os.path.join(os.path.dirname(vcf_file), "masked_" + os.path.basename(vcf_file)) def make_outfile(mask_sites, in_file, out_file, cleanup=True): assert cleanup == False @@ -225,15 +285,106 @@ def make_outfile(mask_sites, in_file, out_file, cleanup=True): mask.run(args) assert os.path.exists(out_file), "Output file incorrectly deleted" - def test_run_normal_case(self, bed_file, vcf_file, tmpdir, argparser, mp_context): - test_outfile = str(tmpdir / "out") - def check_args(mask_sites, in_file, out_file, cleanup): + def test_run_normal_case(self, bed_file, vcf_file, out_file, argparser, mp_context): + def check_args(mask_sites, in_file, _out_file, cleanup): assert mask_sites == TEST_BED_SEQUENCE, "Wrong mask sites provided" assert in_file == vcf_file, "Incorrect input file provided" - assert out_file == test_outfile, "Incorrect output file provided" + assert _out_file == out_file, "Incorrect output file provided" assert cleanup is True, "Cleanup erroneously passed in as False" - open(out_file, "w").close() # want to test we don't delete output. mp_context.setattr(mask, "mask_vcf", check_args) - args = argparser("--mask=%s --sequences=%s --output=%s" %(bed_file, vcf_file, test_outfile)) + args = argparser("--mask=%s --sequences=%s --output=%s" %(bed_file, vcf_file, out_file)) + mask.run(args) + assert os.path.exists(out_file), "Output file incorrectly deleted" + + def test_run_with_mask_sites(self, vcf_file, out_file, argparser, mp_context): + args = argparser("--mask-sites 2 8 -s %s -o %s" % (vcf_file, out_file)) + def check_mask_sites(mask_sites, *args, **kwargs): + # mask-sites are passed to the CLI as one-indexed + assert mask_sites == [1,7] + mp_context.setattr(mask, "mask_vcf", check_mask_sites) + mask.run(args) + + def test_run_with_mask_sites_and_mask_file(self, vcf_file, out_file, bed_file, argparser, mp_context): + args = argparser("--mask-sites 20 21 --mask %s -s %s -o %s" % (bed_file, vcf_file, out_file)) + def check_mask_sites(mask_sites, *args, **kwargs): + # mask-sites are passed to the CLI as one-indexed + assert mask_sites == sorted(set(TEST_BED_SEQUENCE + [19,20])) + mp_context.setattr(mask, "mask_vcf", check_mask_sites) + mask.run(args) + + def test_run_requires_some_masking(self, vcf_file, argparser): + args = argparser("-s %s" % vcf_file) + with pytest.raises(SystemExit) as err: + mask.run(args) + + @pytest.mark.parametrize("op", ("beginning", "end")) + def test_run_vcf_cannot_mask_beginning_or_end(self, vcf_file, argparser, op): + args = argparser("-s %s --mask-from-%s 2" % (vcf_file, op)) + with pytest.raises(SystemExit) as err: + mask.run(args) + + def test_run_fasta_mask_from_beginning_or_end(self, fasta_file, out_file, argparser, mp_context): + args = argparser("-s %s -o %s --mask-from-beginning 2 --mask-from-end 3" % (fasta_file, out_file)) + def check_mask_from(*args, mask_from_beginning=0, mask_from_end=0): + assert mask_from_beginning == 2 + assert mask_from_end == 3 + mp_context.setattr(mask, "mask_fasta", check_mask_from) + mask.run(args) + + def test_e2e_fasta_minimal(self, fasta_file, bed_file, sequences, argparser): + args = argparser("-s %s --mask %s" % (fasta_file, bed_file)) + mask.run(args) + output = SeqIO.parse(fasta_file,"fasta") + for record in output: + reference = sequences[record.id].seq + for idx, site in enumerate(record.seq): + if idx in TEST_BED_SEQUENCE: + assert site == "N" + else: + assert site == reference[idx] + + def test_e2e_fasta_beginning_end_sites(self, fasta_file, bed_file, out_file, sequences, argparser): + from_beginning = 3 + from_end = 1 + arg_sites = [5, 12] + expected_removals = sorted(set(TEST_BED_SEQUENCE + [s - 1 for s in arg_sites])) + print(expected_removals) + args = argparser("-s %s -o %s --mask %s --mask-from-beginning %s --mask-from-end %s --mask-sites %s" % ( + fasta_file, out_file, bed_file, from_beginning, from_end, " ".join(str(s) for s in arg_sites))) + mask.run(args) + output = SeqIO.parse(out_file, "fasta") + for record in output: + reference = str(sequences[record.id].seq) + masked_seq = str(record.seq) + assert masked_seq[:from_beginning] == "N" * from_beginning + assert masked_seq[-from_end:] == "N" * from_end + for idx, site in enumerate(masked_seq[from_beginning:-from_end], from_beginning): + if idx in expected_removals: + assert site == "N" + else: + assert site == reference[idx] + + def test_e2e_vcf_minimal(self, vcf_file, bed_file, argparser): + args = argparser("-s %s --mask %s" % (vcf_file, bed_file)) + mask.run(args) + with open(vcf_file) as output: + assert output.readline().startswith("##fileformat") # is a VCF + assert output.readline().startswith("#CHROM\tPOS\t") # have a header + for line in output.readlines(): + site = int(line.split("\t")[1]) # POS column + site = site - 1 # shift to zero-indexed site + assert site not in TEST_BED_SEQUENCE + + def test_e2e_vcf_with_options(self, vcf_file, bed_file, out_file, argparser): + arg_sites = [5, 12, 14] + expected_removals = sorted(set(TEST_BED_SEQUENCE + [s - 1 for s in arg_sites])) + args = argparser("-s %s -o %s --mask %s --mask-sites %s" % ( + vcf_file, out_file, bed_file, " ".join(str(s) for s in arg_sites))) mask.run(args) - assert os.path.exists(test_outfile), "Output file incorrectly deleted" + with open(out_file) as output: + assert output.readline().startswith("##fileformat") # is a VCF + assert output.readline().startswith("#CHROM\tPOS\t") # have a header + for line in output.readlines(): + site = int(line.split("\t")[1]) # POS column + site = site - 1 #re-zero-index the VCF sites + assert site not in expected_removals