Skip to content

Commit

Permalink
bring back mask
Browse files Browse the repository at this point in the history
  • Loading branch information
rneher committed Jul 22, 2023
1 parent 7913547 commit 8faa4b1
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions augur/ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def ancestral_sequence_inference(tree=None, aln=None, ref=None, infer_gtr=True,

return tt

def collect_mutations_and_sequences(tt, infer_tips=False, full_sequences=False, character_map=None,
mask_ambiguous=True):
def collect_mutations_and_sequences(tt, infer_tips=False, full_sequences=False, character_map=None, is_vcf=False):
"""iterates of the tree and produces dictionaries with
mutations and sequences for each node.
Expand Down Expand Up @@ -104,20 +103,19 @@ def collect_mutations_and_sequences(tt, infer_tips=False, full_sequences=False,
data[n.name]['muts'] = [a+str(int(pos)+inc)+cm(d)
for a,pos,d in n.mutations]

mask=None
if full_sequences:
if mask_ambiguous:
# Identify sites for which every terminal sequence is ambiguous.
# These sites will be masked to prevent rounding errors in the
# maximum likelihood inference from assigning an arbitrary
# nucleotide to sites at internal nodes.
ambiguous_count = np.zeros(tt.sequence_length, dtype=int)
for n in tt.tree.get_terminals():
ambiguous_count += np.array(tt.sequence(n,reconstructed=False, as_string=False)==tt.gtr.ambiguous, dtype=int)
mask = ambiguous_count==len(tt.tree.get_terminals())
else:
mask = np.zeros(tt.sequence_length, dtype=bool)
if is_vcf:
mask = np.zeros(tt.sequence_length, dtype=bool)
else:
# Identify sites for which every terminal sequence is ambiguous.
# These sites will be masked to prevent rounding errors in the
# maximum likelihood inference from assigning an arbitrary
# nucleotide to sites at internal nodes.
ambiguous_count = np.zeros(tt.sequence_length, dtype=int)
for n in tt.tree.get_terminals():
ambiguous_count += np.array(tt.sequence(n,reconstructed=False, as_string=False)==tt.gtr.ambiguous, dtype=int)
mask = ambiguous_count==len(tt.tree.get_terminals())

if full_sequences:
for n in tt.tree.find_clades():
try:
tmp = tt.sequence(n,reconstructed=infer_tips, as_string=False)
Expand Down Expand Up @@ -150,7 +148,7 @@ def run_ancestral(T, aln, root_sequence=None, is_vcf=False, full_sequences=False
root_seq = tt.sequence(T.root, as_string=True)

mutations = collect_mutations_and_sequences(tt, full_sequences=full_sequences,
infer_tips=infer_ambiguous, character_map=character_map)
infer_tips=infer_ambiguous, character_map=character_map, is_vcf=is_vcf)
if root_sequence:
for pos, (root_state, tree_state) in enumerate(zip(root_sequence, tt.sequence(tt.tree.root, reconstructed=infer_ambiguous, as_string=True))):
if root_state != tree_state:
Expand Down Expand Up @@ -249,9 +247,9 @@ def run(args):
# explicitly or by default) and the user has not explicitly requested that
# we keep them.
infer_ambiguous = args.infer_ambiguous and not args.keep_ambiguous

full_sequences = args.output_sequences is not None
nuc_result = run_ancestral(T, aln, root_sequence=str(ref.seq) if ref else None, is_vcf=is_vcf, fill_overhangs=not args.keep_overhangs,
marginal=args.inference, infer_ambiguous=infer_ambiguous, alphabet='nuc')
full_sequences=full_sequences, marginal=args.inference, infer_ambiguous=infer_ambiguous, alphabet='nuc')
anc_seqs = nuc_result['mutations']
anc_seqs['reference'] = {'nuc': nuc_result['root_seq']}

Expand Down

0 comments on commit 8faa4b1

Please sign in to comment.