diff --git a/augur/ancestral.py b/augur/ancestral.py index 8c66041e1..3ee0e2828 100644 --- a/augur/ancestral.py +++ b/augur/ancestral.py @@ -72,8 +72,7 @@ def ancestral_sequence_inference(tree=None, aln=None, ref=None, infer_gtr=True, return tt -def collect_mutations_and_sequences(tt, infer_tips=False, full_sequences=False, character_map=None, - mask_ambiguous=True): +def collect_mutations_and_sequences(tt, infer_tips=False, full_sequences=False, character_map=None, is_vcf=False): """iterates of the tree and produces dictionaries with mutations and sequences for each node. @@ -104,20 +103,19 @@ def collect_mutations_and_sequences(tt, infer_tips=False, full_sequences=False, data[n.name]['muts'] = [a+str(int(pos)+inc)+cm(d) for a,pos,d in n.mutations] - mask=None - if full_sequences: - if mask_ambiguous: - # Identify sites for which every terminal sequence is ambiguous. - # These sites will be masked to prevent rounding errors in the - # maximum likelihood inference from assigning an arbitrary - # nucleotide to sites at internal nodes. - ambiguous_count = np.zeros(tt.sequence_length, dtype=int) - for n in tt.tree.get_terminals(): - ambiguous_count += np.array(tt.sequence(n,reconstructed=False, as_string=False)==tt.gtr.ambiguous, dtype=int) - mask = ambiguous_count==len(tt.tree.get_terminals()) - else: - mask = np.zeros(tt.sequence_length, dtype=bool) + if is_vcf: + mask = np.zeros(tt.sequence_length, dtype=bool) + else: + # Identify sites for which every terminal sequence is ambiguous. + # These sites will be masked to prevent rounding errors in the + # maximum likelihood inference from assigning an arbitrary + # nucleotide to sites at internal nodes. + ambiguous_count = np.zeros(tt.sequence_length, dtype=int) + for n in tt.tree.get_terminals(): + ambiguous_count += np.array(tt.sequence(n,reconstructed=False, as_string=False)==tt.gtr.ambiguous, dtype=int) + mask = ambiguous_count==len(tt.tree.get_terminals()) + if full_sequences: for n in tt.tree.find_clades(): try: tmp = tt.sequence(n,reconstructed=infer_tips, as_string=False) @@ -150,7 +148,7 @@ def run_ancestral(T, aln, root_sequence=None, is_vcf=False, full_sequences=False root_seq = tt.sequence(T.root, as_string=True) mutations = collect_mutations_and_sequences(tt, full_sequences=full_sequences, - infer_tips=infer_ambiguous, character_map=character_map) + infer_tips=infer_ambiguous, character_map=character_map, is_vcf=is_vcf) if root_sequence: for pos, (root_state, tree_state) in enumerate(zip(root_sequence, tt.sequence(tt.tree.root, reconstructed=infer_ambiguous, as_string=True))): if root_state != tree_state: @@ -249,9 +247,9 @@ def run(args): # explicitly or by default) and the user has not explicitly requested that # we keep them. infer_ambiguous = args.infer_ambiguous and not args.keep_ambiguous - + full_sequences = args.output_sequences is not None nuc_result = run_ancestral(T, aln, root_sequence=str(ref.seq) if ref else None, is_vcf=is_vcf, fill_overhangs=not args.keep_overhangs, - marginal=args.inference, infer_ambiguous=infer_ambiguous, alphabet='nuc') + full_sequences=full_sequences, marginal=args.inference, infer_ambiguous=infer_ambiguous, alphabet='nuc') anc_seqs = nuc_result['mutations'] anc_seqs['reference'] = {'nuc': nuc_result['root_seq']}