diff --git a/augur/align.py b/augur/align.py index 83de03991..f7b2914e8 100644 --- a/augur/align.py +++ b/augur/align.py @@ -28,6 +28,64 @@ def register_arguments(parser): parser.add_argument('--existing-alignment', metavar="FASTA", default=False, help="An existing alignment to which the sequences will be added. The ouput alignment will be the same length as this existing alignment.") parser.add_argument('--debug', action="store_true", default=False, help="Produce extra files (e.g. pre- and post-aligner files) which can help with debugging poor alignments.") +def prepare(sequences, existing_aln_fname, output, ref_name, ref_seq_fname): + """Prepare the sequences, existing alignment, and reference sequence for alignment. + + This function: + 1. Combines all given input sequences into a single file + 2. Checks to make sure the input sequences don't overlap with the existing alignment, if one exists. + 3. If given a reference name, check that sequence exists in either the existing alignment, if given, or the input sequences. + 4. If given a reference sequence, either add it to the existing alignment or prepend it to the input seqeunces. + 5. Write the input sequences to a single file, and write the alignment back out if we added the reference sequence to it. + + Parameters + ---------- + sequences : list[str] + List of paths to FASTA-formatted sequences to align. + existing_aln_fname : str + Path of an existing alignment to use, or None + output: str + Path the aligned sequences will be written out to. + ref_name: str + The name of the reference sequence, if provided + ref_seq_fname: str + The path to the reference sequence file. If this is provided, it overrides ref_name. + + Returns + ------- + tuple: The existing alignment filename, the new sequences filename, and the name of the reference sequence. + """ + seqs = read_sequences(*sequences) + seqs_to_align_fname = output + ".to_align.fasta" + + if existing_aln_fname: + existing_aln = read_alignment(existing_aln_fname) + seqs = prune_seqs_matching_alignment(seqs, existing_aln) + else: + existing_aln = None + + if ref_seq_fname: + ref_seq = read_reference(ref_seq_fname) + ref_name = ref_seq.id + if existing_aln: + if len(ref_seq) != existing_aln.get_alignment_length(): + raise AlignmentError("ERROR: Provided existing alignment ({}bp) is not the same length as the reference sequence ({}bp)".format(existing_aln.get_alignment_length(), len(ref_seq))) + existing_aln_fname = existing_aln_fname + ".ref.fasta" + existing_aln.append(ref_seq) + write_seqs(existing_aln, existing_aln_fname) + else: + # reference sequence needs to be the first one for auto direction + # adjustment (auto reverse-complement) + seqs.insert(0, ref_seq) + elif ref_name: + ensure_reference_strain_present(ref_name, existing_aln, seqs) + + write_seqs(seqs, seqs_to_align_fname) + + # 90% sure this is only ever going to catch ref_seq was a dupe + check_duplicates(existing_aln, seqs) + return existing_aln_fname, seqs_to_align_fname, ref_name + def run(args): ''' Parameters @@ -44,48 +102,14 @@ def run(args): try: check_arguments(args) - seqs = read_sequences(*args.sequences) - existing_aln = read_alignment(args.existing_alignment) if args.existing_alignment else None - - # if we have been given a reference (strain) name, make sure it is present - ref_name = args.reference_name - if args.reference_name: - ensure_reference_strain_present(ref_name, existing_aln, seqs) - - # If given an existing alignment, then add the reference sequence to this if desired (and if it is the same length) - if existing_aln and args.reference_sequence: - existing_aln_fname = args.existing_alignment + ".ref.fasta" - ref_seq = read_reference(args.reference_sequence) - if len(ref_seq) != existing_aln.get_alignment_length(): - raise AlignmentError("ERROR: Provided existing alignment ({}bp) is not the same length as the reference sequence ({}bp)".format(existing_aln.get_alignment_length(), len(ref_seq))) - existing_aln.append(ref_seq) - write_seqs(existing_aln, existing_aln_fname) - temp_files_to_remove.append(existing_aln_fname) - ref_name = ref_seq.id - else: - existing_aln_fname = args.existing_alignment # may be False - - ## Create a single file of sequences for alignment (or to be added to the alignment). - ## Add in the reference file to the sequences _if_ we don't have an existing alignment - if args.reference_sequence and not existing_aln: - seqs_to_align_fname = args.output+".to_align.fasta" - ref_seq = read_reference(args.reference_sequence) - # reference sequence needs to be the first one for auto direction adjustment (auto reverse-complement) - write_seqs([ref_seq] + list(seqs.values()), seqs_to_align_fname) - ref_name = ref_seq.id - elif existing_aln: - seqs_to_align_fname = args.output+".new_seqs_to_align.fasta" - seqs = prune_seqs_matching_alignment(seqs, existing_aln) - write_seqs(list(seqs.values()), seqs_to_align_fname) - else: - seqs_to_align_fname = args.output+".to_align.fasta" - write_seqs(list(seqs.values()), seqs_to_align_fname) + existing_aln_fname, seqs_to_align_fname, ref_name = prepare(args.sequences, args.existing_alignment, args.output, args.reference_name, args.reference_sequence) temp_files_to_remove.append(seqs_to_align_fname) - - check_duplicates(existing_aln, ref_name, seqs) + if existing_aln_fname != args.existing_alignment: + temp_files_to_remove.append(existing_aln_fname) + # -- existing_aln_fname, seqs_to_align_fname, ref_name -- # before aligning, make a copy of the data that the aligner receives as input (very useful for debugging purposes) - if args.debug and not existing_aln: + if args.debug and not existing_aln_fname: copyfile(seqs_to_align_fname, args.output+".pre_aligner.fasta") # generate alignment command & run @@ -98,24 +122,7 @@ def run(args): if args.debug: copyfile(args.output, args.output+".post_aligner.fasta") - # reads the new alignment - seqs = read_alignment(args.output) - - # convert the aligner output to upper case and remove auto reverse-complement prefix - prettify_alignment(seqs) - - # if we've specified a reference, strip out all the columns not present in the reference - # this will overwrite the alignment file - if ref_name: - seqs = strip_non_reference(seqs, ref_name, insertion_csv=args.output+".insertions.csv") - if args.remove_reference: - seqs = remove_reference_sequence(seqs, ref_name) - write_seqs(seqs, args.output) - if args.fill_gaps: - make_gaps_ambiguous(seqs) - - # write the modified sequences back to the alignment file - write_seqs(seqs, args.output) + postprocess(args.output, ref_name, not args.remove_reference, args.fill_gaps) except AlignmentError as e: @@ -126,9 +133,50 @@ def run(args): for fname in temp_files_to_remove: os.remove(fname) + +def postprocess(output_file, ref_name, keep_reference, fill_gaps): + """Postprocessing of the combined alignment file. + + Parameters + ---------- + output_file: str + The file the new alignment was written to + ref_name: str + If provided, the name of the reference strain used in the alignment + keep_reference: bool + If the reference was provided, whether it should be kept in the alignment + fill_gaps: bool + Replace all gaps in the alignment with "N" to indicate ambiguous sites. + + Returns + ------- + None - the modified alignment is written directly to output_file + """ + # -- ref_name -- + # reads the new alignment + seqs = read_alignment(output_file) + # convert the aligner output to upper case and remove auto reverse-complement prefix + prettify_alignment(seqs) + + # if we've specified a reference, strip out all the columns not present in the reference + # this will overwrite the alignment file + if ref_name: + seqs = strip_non_reference(seqs, ref_name, insertion_csv=output_file+".insertions.csv") + if not keep_reference: + seqs = remove_reference_sequence(seqs, ref_name) + + if fill_gaps: + make_gaps_ambiguous(seqs) + + # write the modified sequences back to the alignment file + write_seqs(seqs, output_file) + + + ##################################################################################################### def read_sequences(*fnames): + """return list of sequences from all fnames""" seqs = {} try: for fname in fnames: @@ -141,7 +189,7 @@ def read_sequences(*fnames): raise AlignmentError("\nCannot read sequences -- make sure the file %s exists and contains sequences in fasta format" % fname) except ValueError as error: raise AlignmentError("\nERROR: Problem reading in {}: {}".format(fname, str(error))) - return seqs + return list(seqs.values()) def check_arguments(args): # Simple error checking related to a reference name/sequence @@ -161,7 +209,7 @@ def ensure_reference_strain_present(ref_name, existing_alignment, seqs): if ref_name not in {x.name for x in existing_alignment}: raise AlignmentError("ERROR: Specified reference name %s (via --reference-name) is not in the supplied alignment."%ref_name) else: - if ref_name not in seqs: + if ref_name not in {x.name for x in seqs}: raise AlignmentError("ERROR: Specified reference name %s (via --reference-name) is not in the sequence sample."%ref_name) @@ -345,19 +393,15 @@ def add(name): if name in names: raise AlignmentError("Duplicate strains of \"{}\" detected".format(name)) names.add(name) - for sample in values: if not sample: # allows false-like values (e.g. always provide existing_alignment, allowing # the default which is `False`) continue - elif type(sample) == dict: - for s in sample: - add(s) - elif type(sample) == Align.MultipleSeqAlignment: + elif isinstance(sample, (list, Align.MultipleSeqAlignment)): for s in sample: add(s.name) - elif type(sample) == str: + elif isinstance(sample, str): add(sample) else: raise TypeError() @@ -372,14 +416,14 @@ def write_seqs(seqs, fname): def prune_seqs_matching_alignment(seqs, aln): """ - Return a set of seqs excluding those set via `exclude` & print a warning + Return a set of seqs excluding those already in the alignment & print a warning message for each sequence which is exluded. """ - ret = {} - exclude_names = {s.name for s in aln} - for name, seq in seqs.items(): - if name in exclude_names: - print("Excluding {} as it is already present in the alignment".format(name)) + ret = [] + aln_names = {s.name for s in aln} + for seq in seqs: + if seq.name in aln_names: + print("Excluding {} as it is already present in the alignment".format(seq.name)) else: - ret[name] = seq + ret.append(seq) return ret diff --git a/augur/filter.py b/augur/filter.py index 700bdff99..efcaab8bb 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -8,6 +8,8 @@ import random, os, re import numpy as np import sys +import datetime +import treetime.utils from .utils import read_metadata, get_numerical_dates, run_shell_command, shquote comment_char = '#' @@ -87,8 +89,8 @@ def filter_by_query(sequences, metadata_file, query): def register_arguments(parser): parser.add_argument('--sequences', '-s', required=True, help="sequences in fasta or VCF format") parser.add_argument('--metadata', required=True, help="metadata associated with sequences") - parser.add_argument('--min-date', type=float, help="minimal cutoff for numerical date") - parser.add_argument('--max-date', type=float, help="maximal cutoff for numerical date") + parser.add_argument('--min-date', type=numeric_date, help="minimal cutoff for date; may be specified as an Augur-style numeric date (with the year as the integer part) or YYYY-MM-DD") + parser.add_argument('--max-date', type=numeric_date, help="maximal cutoff for date; may be specified as an Augur-style numeric date (with the year as the integer part) or YYYY-MM-DD") parser.add_argument('--min-length', type=int, help="minimal length of the sequences") parser.add_argument('--non-nucleotide', action='store_true', help="exclude sequences that contain illegal characters") parser.add_argument('--exclude', type=str, help="file with list of strains that are to be excluded") @@ -410,3 +412,21 @@ def run(args): def _filename_gz(filename): return filename.lower().endswith(".gz") + + +def numeric_date(date): + """ + Converts the given *date* string to a :py:class:`float`. + + *date* may be given as a number (a float) with year as the integer part, or + in the YYYY-MM-DD (ISO 8601) syntax. + + >>> numeric_date("2020.42") + 2020.42 + >>> numeric_date("2020-06-04") + 2020.42486... + """ + try: + return float(date) + except ValueError: + return treetime.utils.numeric_date(datetime.date(*map(int, date.split("-", 2)))) diff --git a/augur/util_support/__init__.py b/augur/util_support/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/augur/util_support/date_disambiguator.py b/augur/util_support/date_disambiguator.py new file mode 100644 index 000000000..713295eac --- /dev/null +++ b/augur/util_support/date_disambiguator.py @@ -0,0 +1,123 @@ +import calendar +import datetime +import functools +import re + + +def tuple_to_date(year, month, day): + month = min(month, 12) + day = min(day, max_day_for_year_month(year, month)) + + return datetime.date(year=year, month=month, day=day) + + +def max_day_for_year_month(year, month): + return calendar.monthrange(year, month)[1] + + +def resolve_uncertain_int(uncertain_string, min_or_max): + """ + Takes a string representation of an integer with uncertain places + occupied by the character `X`. Returns the minimum or maximum + possible integer. + """ + if min_or_max == "min": + result = int(uncertain_string.replace("X", "0")) + elif min_or_max == "max": + result = int(uncertain_string.replace("X", "9")) + else: + raise "Tried to resolve an uncertain integer to something other than `min` or `max`." + + if result == 0: + # A date component cannot be 0. Well, year can, but... + result = 1 + + return result + + +class DateDisambiguator: + """Transforms a date string with uncertainty into the range of possible dates.""" + + def __init__(self, uncertain_date, fmt="%Y-%m-%d", min_max_year=None): + self.uncertain_date = uncertain_date + self.fmt = fmt + self.min_max_year = min_max_year + + self.assert_only_less_significant_uncertainty() + + def range(self): + min_date = tuple_to_date( + resolve_uncertain_int(self.uncertain_date_components["Y"], "min"), + resolve_uncertain_int(self.uncertain_date_components["m"], "min"), + resolve_uncertain_int(self.uncertain_date_components["d"], "min"), + ) + + max_date = tuple_to_date( + resolve_uncertain_int(self.uncertain_date_components["Y"], "max"), + resolve_uncertain_int(self.uncertain_date_components["m"], "max"), + resolve_uncertain_int(self.uncertain_date_components["d"], "max"), + ) + max_date = min(max_date, datetime.date.today()) + + return (min_date, max_date) + + @property + @functools.lru_cache() + def uncertain_date_components(self): + matches = re.search(self.regex, self.uncertain_date) + + if matches is None: + raise ValueError( + f"Malformed uncertain date `{self.uncertain_date}` for format `{self.fmt}`" + ) + + return dict(zip(self.fmt_components, matches.groups())) + + @property + @functools.lru_cache() + def fmt_components(self): + # The `re` module doesn't capture repeated groups, so we'll do it without regexes + return [component[0] for component in self.fmt.split("%") if len(component) > 0] + + @property + def regex(self): + """ + Returns regex defined by the format string. + Currently only supports %Y, %m, and %d. + """ + return re.compile( + "^" + + self.fmt.replace("%Y", "(....)") + .replace("%m", "(..?)") + .replace("%d", "(..?)") + + "$" + ) + + def assert_only_less_significant_uncertainty(self): + """ + Raise an exception if a constrained digit appears in a less-significant place + than an uncertain digit. + + Assuming %Y-%m-%d, these patterns are valid: + 2000-01-01 + 2000-01-XX + 2000-XX-XX + + but this is invalid, because month is uncertain but day is constrained: + 2000-XX-01 + + These invalid cases are assumed to be unintended use of the tool. + """ + if "X" in self.uncertain_date_components["Y"]: + if ( + self.uncertain_date_components["m"] != "XX" + or self.uncertain_date_components["d"] != "XX" + ): + raise ValueError( + "Invalid date: Year contains uncertainty, so month and day must also be uncertain." + ) + elif "X" in self.uncertain_date_components["m"]: + if self.uncertain_date_components["d"] != "XX": + raise ValueError( + "Invalid date: Month contains uncertainty, so day must also be uncertain." + ) diff --git a/augur/utils.py b/augur/utils.py index 6f8531c84..0449da24e 100644 --- a/augur/utils.py +++ b/augur/utils.py @@ -16,6 +16,8 @@ import packaging.version as packaging_version from .validate import validate, ValidateError, load_json_schema +from augur.util_support.date_disambiguator import DateDisambiguator + class AugurException(Exception): pass @@ -62,38 +64,8 @@ def get_json_name(args, default=None): raise ValueError("Please specify a name for the JSON file containing the results.") -def ambiguous_date_to_date_range(mydate, fmt, min_max_year=None): - from datetime import datetime - sep = fmt.split('%')[1][-1] - min_date, max_date = {}, {} - today = datetime.today().date() - - for val, field in zip(mydate.split(sep), fmt.split(sep+'%')): - f = 'year' if 'y' in field.lower() else ('day' if 'd' in field.lower() else 'month') - if 'XX' in val: - if f=='year': - if min_max_year: - min_date[f]=min_max_year[0] - if len(min_max_year)>1: - max_date[f]=min_max_year[1] - elif len(min_max_year)==1: - max_date[f]=4000 #will be replaced by 'today' below. - else: - return None, None - elif f=='month': - min_date[f]=1 - max_date[f]=12 - elif f=='day': - min_date[f]=1 - max_date[f]=31 - else: - min_date[f]=int(val) - max_date[f]=int(val) - max_date['day'] = min(max_date['day'], 31 if max_date['month'] in [1,3,5,7,8,10,12] - else 28 if max_date['month']==2 else 30) - lower_bound = datetime(year=min_date['year'], month=min_date['month'], day=min_date['day']).date() - upper_bound = datetime(year=max_date['year'], month=max_date['month'], day=max_date['day']).date() - return (lower_bound, upper_bound if upper_bound