Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Support multiple inputs during filter #697

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 71 additions & 58 deletions augur/align.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
"""
Align multiple sequences from FASTA.
"""

import hashlib
from itertools import chain
import os
from pathlib import Path
from shutil import copyfile
import numpy as np
from Bio import AlignIO, SeqIO, Seq, Align
from .utils import run_shell_command, nthreads_value, shquote
from collections import defaultdict

from .io import open_file, read_sequences as io_read_sequences, write_sequences


class AlignmentError(Exception):
# TODO: this exception should potentially be renamed and made augur-wide
# thus allowing any module to raise it and have the message printed & augur
Expand Down Expand Up @@ -58,11 +63,12 @@ def prepare(sequences, existing_aln_fname, output, ref_name, ref_seq_fname):
seqs = read_sequences(*sequences)
seqs_to_align_fname = output + ".to_align.fasta"

existing_aln = None
existing_aln_sequence_names = set()

if existing_aln_fname:
existing_aln = read_alignment(existing_aln_fname)
seqs = prune_seqs_matching_alignment(seqs, existing_aln)
else:
existing_aln = None

if ref_seq_fname:
ref_seq = read_reference(ref_seq_fname)
Expand All @@ -72,18 +78,22 @@ def prepare(sequences, existing_aln_fname, output, ref_name, ref_seq_fname):
raise AlignmentError("ERROR: Provided existing alignment ({}bp) is not the same length as the reference sequence ({}bp)".format(existing_aln.get_alignment_length(), len(ref_seq)))
existing_aln_fname = existing_aln_fname + ".ref.fasta"
existing_aln.append(ref_seq)
write_seqs(existing_aln, existing_aln_fname)
existing_aln_sequence_names = write_seqs(existing_aln, existing_aln_fname)
else:
# reference sequence needs to be the first one for auto direction
# adjustment (auto reverse-complement)
seqs.insert(0, ref_seq)
elif ref_name:
ensure_reference_strain_present(ref_name, existing_aln, seqs)
seqs = chain((ref_seq,), seqs)

alignment_sequence_names = write_seqs(seqs, seqs_to_align_fname, ref_name)

write_seqs(seqs, seqs_to_align_fname)
# Check for duplicates in the intersection of existing and new alignment
# sequences.
duplicate_sequence_names = existing_aln_sequence_names & alignment_sequence_names
if len(duplicate_sequence_names) > 0:
raise AlignmentError(
f"Duplicate strains detected: {', '.join(duplicate_sequence_names)}"
)

# 90% sure this is only ever going to catch ref_seq was a dupe
check_duplicates(existing_aln, seqs)
return existing_aln_fname, seqs_to_align_fname, ref_name

def run(args):
Expand Down Expand Up @@ -178,19 +188,34 @@ def postprocess(output_file, ref_name, keep_reference, fill_gaps):

def read_sequences(*fnames):
"""return list of sequences from all fnames"""
seqs = {}
sequence_hash_by_name = {}

try:
for fname in fnames:
for record in SeqIO.parse(fname, 'fasta'):
if record.name in seqs and record.seq != seqs[record.name].seq:
# Stream sequences from all input files into a single output file,
# skipping duplicate records (same strain and sequence) and noting
# mismatched sequences for the same strain name.
for record in io_read_sequences(*fnames):
# Hash each sequence and check whether another sequence with the
# same name already exists and if the hash is different.
sequence_hash = hashlib.sha256(str(record.seq).encode("utf-8")).hexdigest()
if record.name in sequence_hash_by_name:
# If the hashes differ (multiple entries with the same strain
# name but different sequences), we keep the first sequence and
# add the strain to a list of duplicates to report at the end.
if sequence_hash_by_name.get(record.name) != sequence_hash:
raise AlignmentError("Detected duplicate input strains \"%s\" but the sequences are different." % record.name)
# if the same sequence then we can proceed (and we only take one)
seqs[record.name] = record

# If the current strain has been seen before, don't use its
# sequence again.
continue

sequence_hash_by_name[record.name] = sequence_hash
yield record

except FileNotFoundError:
raise AlignmentError("\nCannot read sequences -- make sure the file %s exists and contains sequences in fasta format" % fname)
except ValueError as error:
raise AlignmentError("\nERROR: Problem reading in {}: {}".format(fname, str(error)))
return list(seqs.values())

def check_arguments(args):
# Simple error checking related to a reference name/sequence
Expand All @@ -201,31 +226,29 @@ def check_arguments(args):

def read_alignment(fname):
try:
return AlignIO.read(fname, 'fasta')
with open_file(fname) as handle:
alignment = AlignIO.read(handle, "fasta")

return alignment
except Exception as error:
raise AlignmentError("\nERROR: Problem reading in {}: {}".format(fname, str(error)))

def ensure_reference_strain_present(ref_name, existing_alignment, seqs):
    """Verify that the requested reference strain exists in the inputs.

    When *existing_alignment* is provided (truthy), *ref_name* must appear
    among the names of its records; otherwise it must appear among the
    records of *seqs*. Raises AlignmentError when the name cannot be found.
    """
    if existing_alignment:
        known_names = {record.name for record in existing_alignment}
        if ref_name not in known_names:
            raise AlignmentError("ERROR: Specified reference name %s (via --reference-name) is not in the supplied alignment."%ref_name)
    else:
        known_names = {record.name for record in seqs}
        if ref_name not in known_names:
            raise AlignmentError("ERROR: Specified reference name %s (via --reference-name) is not in the sequence sample."%ref_name)


# align
# if args.method=='mafft':
# shoutput = shquote(output)
# shname = shquote(seq_fname)
# cmd = "mafft --reorder --anysymbol --thread %d %s 1> %s 2> %s.log"%(args.nthreads, shname, shoutput, shoutput)

def read_reference(ref_fname):
if not os.path.isfile(ref_fname):
raise AlignmentError("ERROR: Cannot read reference sequence."
"\n\tmake sure the file \"%s\" exists"%ref_fname)

genbank_suffixes = {".gb", ".genbank"}
ref_fname_path = Path(ref_fname)

# Check for GenBank suffixes, while allowing for compression suffixes.
if len(set(ref_fname_path.suffixes) & genbank_suffixes) > 0:
format = "genbank"
else:
format = "fasta"

try:
ref_seq = SeqIO.read(ref_fname, 'genbank' if ref_fname.split('.')[-1] in ['gb', 'genbank'] else 'fasta')
ref_seq = next(io_read_sequences(ref_fname, format=format))
except:
raise AlignmentError("ERROR: Cannot read reference sequence."
"\n\tmake sure the file %s contains one sequence in genbank or fasta format"%ref_fname)
Expand Down Expand Up @@ -388,43 +411,33 @@ def make_gaps_ambiguous(aln):
seq.seq = Seq.Seq(_seq)


def check_duplicates(*values):
    """Raise AlignmentError if any strain name appears more than once.

    Each positional argument may be: falsy (silently skipped, so callers can
    always pass an argument whose default is ``False``), a string naming a
    single strain, or a list / Bio.Align.MultipleSeqAlignment whose records'
    ``.name`` attributes are checked. Any other type raises TypeError.
    """
    seen = set()

    def note(name):
        # A repeat across *any* of the supplied collections is an error.
        if name in seen:
            raise AlignmentError("Duplicate strains of \"{}\" detected".format(name))
        seen.add(name)

    for value in values:
        if not value:
            # allows false-like values (e.g. always provide existing_alignment, allowing
            # the default which is `False`)
            continue
        if isinstance(value, (list, Align.MultipleSeqAlignment)):
            for record in value:
                note(record.name)
        elif isinstance(value, str):
            note(value)
        else:
            raise TypeError()

def write_seqs(seqs, fname):
def write_seqs(seqs, fname, ref_name=None):
"""A wrapper around SeqIO.write with error handling"""
sequences_written = set()

try:
SeqIO.write(seqs, fname, 'fasta')
with open_file(fname, "wt") as handle:
for sequence in seqs:
sequences_written.add(sequence.id)
write_sequences(sequence, handle)

except FileNotFoundError:
raise AlignmentError('ERROR: Couldn\'t write "{}" -- perhaps the directory doesn\'t exist?'.format(fname))

if ref_name is not None and ref_name not in sequences_written:
raise AlignmentError(f"ERROR: Specified reference name {ref_name} (via --reference-name) is not in the sequence sample.")

return sequences_written


def prune_seqs_matching_alignment(seqs, aln):
"""
Return a set of seqs excluding those already in the alignment & print a warning
message for each sequence which is exluded.
"""
ret = []
aln_names = {s.name for s in aln}
for seq in seqs:
if seq.name in aln_names:
print("Excluding {} as it is already present in the alignment".format(seq.name))
else:
ret.append(seq)
return ret
yield seq
33 changes: 12 additions & 21 deletions augur/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import treetime.utils

from .index import index_sequences
from .utils import read_metadata, read_strains, get_numerical_dates, run_shell_command, shquote, is_date_ambiguous
from .io import open_file, read_sequences, write_sequences
from .utils import read_metadata, read_sequence_index, read_strains, get_numerical_dates, run_shell_command, shquote, is_date_ambiguous

comment_char = '#'
MAX_NUMBER_OF_PROBABILISTIC_SAMPLING_ATTEMPTS = 10
Expand Down Expand Up @@ -93,9 +94,9 @@ def filter_by_query(sequences, metadata_file, query):

def register_arguments(parser):
input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered")
input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata, as CSV or TSV")
input_group.add_argument('--sequences', '-s', help="sequences in FASTA or VCF format")
input_group.add_argument('--sequence-index', help="sequence composition report generated by augur index. If not provided, an index will be created on the fly.")
input_group.add_argument('--metadata', nargs="+", required=True, metavar="FILE", help="sequence metadata, as CSV or TSV")
input_group.add_argument('--sequences', '-s', nargs="*", help="sequences in FASTA or VCF format")
input_group.add_argument('--sequence-index', nargs="*", help="sequence composition report generated by augur index. If not provided, an index will be created on the fly.")

metadata_filter_group = parser.add_argument_group("metadata filters", "filters to apply to metadata")
metadata_filter_group.add_argument(
Expand Down Expand Up @@ -170,20 +171,13 @@ def run(args):
return 1

# Load inputs, starting with metadata.
try:
# Metadata are the source of truth for which sequences we want to keep
# in filtered output.
meta_dict, meta_columns = read_metadata(args.metadata)
metadata_strains = set(meta_dict.keys())
except ValueError as error:
print("ERROR: Problem reading in {}:".format(args.metadata))
print(error)
return 1
meta_dict, meta_columns = read_metadata(*args.metadata)
metadata_strains = set(meta_dict.keys())

#Set flags if VCF
is_vcf = False
is_compressed = False
if args.sequences and any([args.sequences.lower().endswith(x) for x in ['.vcf', '.vcf.gz']]):
if args.sequences and len(args.sequences) == 1 and any([args.sequences[0].lower().endswith(x) for x in ['.vcf', '.vcf.gz']]):
is_vcf = True
if args.sequences.lower().endswith('.gz'):
is_compressed = True
Expand Down Expand Up @@ -225,10 +219,7 @@ def run(args):
)
index_sequences(args.sequences, sequence_index_path)

sequence_index = pd.read_csv(
sequence_index_path,
sep="\t"
)
sequence_index = read_sequence_index(*sequence_index_path)

# Remove temporary index file, if it exists.
if index_is_autogenerated:
Expand Down Expand Up @@ -545,19 +536,19 @@ def run(args):
dropped_samps = list(available_strains - seq_keep)
write_vcf(args.sequences, args.output, dropped_samps)
elif args.sequences and args.output:
sequences = SeqIO.parse(args.sequences, "fasta")
sequences = read_sequences(*args.sequences)

# Stream to disk all sequences that passed all filters to avoid reading
# sequences into memory first. Track the observed strain names in the
# sequence file as part of the single pass to allow comparison with the
# provided sequence index.
observed_sequence_strains = set()
with open(args.output, "w") as output_handle:
with open_file(args.output, "wt") as output_handle:
for sequence in sequences:
observed_sequence_strains.add(sequence.id)

if sequence.id in seq_keep:
SeqIO.write(sequence, output_handle, 'fasta')
write_sequences(sequence, output_handle, 'fasta')

if sequence_strains != observed_sequence_strains:
# Warn the user if the expected strains from the sequence index are
Expand Down
21 changes: 12 additions & 9 deletions augur/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
import sys
import csv

from .io import open_file, read_sequences


def register_arguments(parser):
parser.add_argument('--sequences', '-s', required=True, help="sequences in fasta format")
parser.add_argument('--output', '-o', help="tab-delimited file containing the number of bases per sequence in the given file. Output columns include strain, length, and counts for A, C, G, T, N, other valid IUPAC characters, ambiguous characters ('?' and '-'), and other invalid characters.", required=True)
parser.add_argument('--verbose', '-v', action="store_true", help="print index statistics to stdout")


def index_sequence(sequence, values):
"""Count the number of nucleotides for a given sequence record.

Expand Down Expand Up @@ -127,13 +131,7 @@ def index_sequences(sequences_path, sequence_index_path):
total length of sequences indexed

"""
#read in files
try:
seqs = SeqIO.parse(sequences_path, 'fasta')
except ValueError as error:
print("ERROR: Problem reading in {}:".format(sequences_path), file=sys.stderr)
print(error, file=sys.stderr)
return 1
seqs = read_sequences(sequences_path)

other_IUPAC = {'r', 'y', 's', 'w', 'k', 'm', 'd', 'h', 'b', 'v'}
values = [{'a'},{'c'},{'g'},{'t'},{'n'},other_IUPAC,{'-'},{'?'}]
Expand All @@ -142,7 +140,7 @@ def index_sequences(sequences_path, sequence_index_path):
tot_length = 0
num_of_seqs = 0

with open(sequence_index_path, 'wt') as out_file:
with open_file(sequence_index_path, 'wt') as out_file:
tsv_writer = csv.writer(out_file, delimiter = '\t')

#write header i output file
Expand All @@ -166,7 +164,12 @@ def run(args):
("?" and "-"), and other invalid characters in a set of sequences and write
the composition as a data frame to the given sequence index path.
'''
num_of_seqs, tot_length = index_sequences(args.sequences, args.output)
try:
num_of_seqs, tot_length = index_sequences(args.sequences, args.output)
except ValueError as error:
print("ERROR: Problem reading in {}:".format(sequences_path), file=sys.stderr)
print(error, file=sys.stderr)
return 1

if args.verbose:
print("Analysed %i sequences with an average length of %i nucleotides." % (num_of_seqs, int(tot_length / num_of_seqs)))
Loading