Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle missing weight attributes in KDE frequency estimation #426

Merged
merged 2 commits into from
Dec 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions augur/frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from Bio.Align import MultipleSeqAlignment

from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies
from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies
from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies, TreeKdeFrequenciesError
from .utils import read_metadata, read_node_data, write_json, get_numerical_dates


Expand Down Expand Up @@ -162,7 +162,12 @@ def run(args):
include_internal_nodes=args.include_internal_nodes,
censored=args.censored
)
frequencies = kde_frequencies.estimate(tree)

try:
frequencies = kde_frequencies.estimate(tree)
except TreeKdeFrequenciesError as e:
print("ERROR: %s" % str(e), file=sys.stderr)
return 1

# Export frequencies in auspice-format by strain name.
frequency_dict = {"pivots": list(kde_frequencies.pivots)}
Expand Down
15 changes: 14 additions & 1 deletion augur/frequency_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
log_thres = 10.0


class TreeKdeFrequenciesError(Exception):
    """Represents an error estimating KDE frequencies for a tree.

    Raised during tree KDE frequency estimation, for example when none of the
    provided weight attributes are represented by tips in the given tree.
    Callers can catch this exception to report the problem to the user
    instead of failing with an unrelated error further downstream.
    """


def get_pivots(observations, pivot_interval, start_date=None, end_date=None):
"""Calculate pivots for a given list of floating point observation dates and
interval between pivots.
Expand Down Expand Up @@ -1143,6 +1149,13 @@ def estimate(self, tree):
for key, value in self.weights.items():
self.weights[key] = value / weight_total

# Confirm that one or more weights are represented by tips in the
# tree. If there are no more weights, raise an exception because
# this likely represents a data error (either in the tree
# annotations or the weight definitions).
if len(self.weights) == 0:
raise TreeKdeFrequenciesError("None of the provided weight attributes were represented by tips in the given tree. Doublecheck weight attribute definitions and their representations in the tree.")

# Estimate frequencies for all tips within each weight attribute
# group.
weight_keys, weight_values = zip(*sorted(self.weights.items()))
Expand All @@ -1153,7 +1166,7 @@ def estimate(self, tree):
# Find tips with the current weight attribute.
tips = [(tip.name, tip.attr["num_date"])
for tip in tree.get_terminals()
if tip.attr[self.weights_attribute].lower() == weight_key and self.tip_passes_filters(tip)]
if tip.attr[self.weights_attribute] == weight_key and self.tip_passes_filters(tip)]
frequencies.update(self.estimate_tip_frequencies_to_proportion(tips, proportion))
else:
tips = [(tip.name, tip.attr["num_date"])
Expand Down
15 changes: 14 additions & 1 deletion tests/python3/test_frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# we assume (and assert) that this script is running from the tests/ directory
sys.path.append(str(Path(__file__).parent.parent.parent))

from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies
from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies, TreeKdeFrequenciesError
from augur.utils import json_to_tree

# Define regions to use for testing weighted frequencies.
Expand Down Expand Up @@ -173,6 +173,19 @@ def test_weighted_estimate_with_unrepresented_weights(self, tree):
# Frequencies should sum to 1 at all pivots.
assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots))

# Estimate weighted frequencies such that all weighted attributes are
# missing. This should raise an exception because none of the tips will
# match any of the weights and the weighting of frequencies will be
# impossible.
weights = {"fake_region_1": 1.0, "fake_region_2": 2.0}
kde_frequencies = TreeKdeFrequencies(
weights=weights,
weights_attribute="region"
)

with pytest.raises(TreeKdeFrequenciesError):
frequencies = kde_frequencies.estimate(tree)

def test_only_tip_estimates(self, tree):
"""Test frequency estimation for only tips in a given tree.
"""
Expand Down