Skip to content

Commit

Permalink
Handle missing weight attributes in KDE frequency estimation
Browse files Browse the repository at this point in the history
Throw an exception if the user has requested weighted KDE frequencies with
weights that do not match any of the tips in the given tree. This commit
explicitly checks for an empty dictionary of weights after filtering for
representation by tips and raises an exception with a meaningful error
message (instead of allowing the code to continue running and throwing a less
meaningful ValueError when no valid weights remain). This commit also adds a
unit test for this behavior.

Closes #425.
  • Loading branch information
huddlej committed Dec 16, 2019
1 parent 1bb1456 commit 6e006b9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
9 changes: 7 additions & 2 deletions augur/frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from Bio.Align import MultipleSeqAlignment

from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies
from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies
from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies, TreeKdeFrequenciesError
from .utils import read_metadata, read_node_data, write_json, get_numerical_dates


Expand Down Expand Up @@ -162,7 +162,12 @@ def run(args):
include_internal_nodes=args.include_internal_nodes,
censored=args.censored
)
frequencies = kde_frequencies.estimate(tree)

try:
frequencies = kde_frequencies.estimate(tree)
except TreeKdeFrequenciesError as e:
print("ERROR: %s" % str(e), file=sys.stderr)
return 1

# Export frequencies in auspice-format by strain name.
frequency_dict = {"pivots": list(kde_frequencies.pivots)}
Expand Down
13 changes: 13 additions & 0 deletions augur/frequency_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
log_thres = 10.0


class TreeKdeFrequenciesError(Exception):
    """Represent an error estimating KDE frequencies for a tree."""


def get_pivots(observations, pivot_interval, start_date=None, end_date=None):
"""Calculate pivots for a given list of floating point observation dates and
interval between pivots.
Expand Down Expand Up @@ -1143,6 +1149,13 @@ def estimate(self, tree):
for key, value in self.weights.items():
self.weights[key] = value / weight_total

# Confirm that one or more weights are represented by tips in the
# tree. If there are no more weights, raise an exception because
# this likely represents a data error (either in the tree
# annotations or the weight definitions).
if len(self.weights) == 0:
raise TreeKdeFrequenciesError("None of the provided weight attributes were represented by tips in the given tree. Doublecheck weight attribute definitions and their representations in the tree.")

# Estimate frequencies for all tips within each weight attribute
# group.
weight_keys, weight_values = zip(*sorted(self.weights.items()))
Expand Down
15 changes: 14 additions & 1 deletion tests/python3/test_frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# we assume (and assert) that this script is running from the tests/ directory
sys.path.append(str(Path(__file__).parent.parent.parent))

from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies
from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies, TreeKdeFrequenciesError
from augur.utils import json_to_tree

# Define regions to use for testing weighted frequencies.
Expand Down Expand Up @@ -173,6 +173,19 @@ def test_weighted_estimate_with_unrepresented_weights(self, tree):
# Frequencies should sum to 1 at all pivots.
assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots))

# Estimate weighted frequencies such that all weighted attributes are
# missing. This should raise an exception because none of the tips will
# match any of the weights and the weighting of frequencies will be
# impossible.
weights = {"fake_region_1": 1.0, "fake_region_2": 2.0}
kde_frequencies = TreeKdeFrequencies(
weights=weights,
weights_attribute="region"
)

with pytest.raises(TreeKdeFrequenciesError):
frequencies = kde_frequencies.estimate(tree)

def test_only_tip_estimates(self, tree):
"""Test frequency estimation for only tips in a given tree.
"""
Expand Down

0 comments on commit 6e006b9

Please sign in to comment.