From 6e006b9246312ca2b5d6de46c8ba06bee91595a1 Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Mon, 16 Dec 2019 10:48:27 -0800 Subject: [PATCH 1/2] Handle missing weight attributes in KDE frequency estimation Throw an exception if the user has requested weighted KDE frequencies with weights that do not match any of the tips in the given tree. This commit explicitly checks for an empty dictionary of weights after filtering for representation by tips and raises an exception with a meaningful error message (instead of allowing the code to continue running and throw a less meaningful ValueError when no valid weights remain). This commit also adds a unit test for this behavior. Closes #425. --- augur/frequencies.py | 9 +++++++-- augur/frequency_estimators.py | 13 +++++++++++++ tests/python3/test_frequencies.py | 15 ++++++++++++++- 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/augur/frequencies.py b/augur/frequencies.py index 57db3052e..a1dd3b572 100644 --- a/augur/frequencies.py +++ b/augur/frequencies.py @@ -8,7 +8,7 @@ from Bio.Align import MultipleSeqAlignment from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies -from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies +from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies, TreeKdeFrequenciesError from .utils import read_metadata, read_node_data, write_json, get_numerical_dates @@ -162,7 +162,12 @@ def run(args): include_internal_nodes=args.include_internal_nodes, censored=args.censored ) - frequencies = kde_frequencies.estimate(tree) + + try: + frequencies = kde_frequencies.estimate(tree) + except TreeKdeFrequenciesError as e: + print("ERROR: %s" % str(e), file=sys.stderr) + return 1 # Export frequencies in auspice-format by strain name. 
frequency_dict = {"pivots": list(kde_frequencies.pivots)} diff --git a/augur/frequency_estimators.py b/augur/frequency_estimators.py index 6bc3014b7..54eaede65 100644 --- a/augur/frequency_estimators.py +++ b/augur/frequency_estimators.py @@ -12,6 +12,12 @@ log_thres = 10.0 +class TreeKdeFrequenciesError(Exception): + """Represents an error estimating KDE frequencies for a tree. + """ + pass + + def get_pivots(observations, pivot_interval, start_date=None, end_date=None): """Calculate pivots for a given list of floating point observation dates and interval between pivots. @@ -1143,6 +1149,13 @@ def estimate(self, tree): for key, value in self.weights.items(): self.weights[key] = value / weight_total + # Confirm that one or more weights are represented by tips in the + # tree. If there are no more weights, raise an exception because + # this likely represents a data error (either in the tree + # annotations or the weight definitions). + if len(self.weights) == 0: + raise TreeKdeFrequenciesError("None of the provided weight attributes were represented by tips in the given tree. Doublecheck weight attribute definitions and their representations in the tree.") + # Estimate frequencies for all tips within each weight attribute # group. weight_keys, weight_values = zip(*sorted(self.weights.items())) diff --git a/tests/python3/test_frequencies.py b/tests/python3/test_frequencies.py index c2f653cff..6f483074e 100644 --- a/tests/python3/test_frequencies.py +++ b/tests/python3/test_frequencies.py @@ -12,7 +12,7 @@ # we assume (and assert) that this script is running from the tests/ directory sys.path.append(str(Path(__file__).parent.parent.parent)) -from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies +from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies, TreeKdeFrequenciesError from augur.utils import json_to_tree # Define regions to use for testing weighted frequencies. 
@@ -173,6 +173,19 @@ def test_weighted_estimate_with_unrepresented_weights(self, tree): # Frequencies should sum to 1 at all pivots. assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots)) + # Estimate weighted frequencies such that all weighted attributes are + # missing. This should raise an exception because none of the tips will + # match any of the weights and the weighting of frequencies will be + # impossible. + weights = {"fake_region_1": 1.0, "fake_region_2": 2.0} + kde_frequencies = TreeKdeFrequencies( + weights=weights, + weights_attribute="region" + ) + + with pytest.raises(TreeKdeFrequenciesError): + frequencies = kde_frequencies.estimate(tree) + def test_only_tip_estimates(self, tree): """Test frequency estimation for only tips in a given tree. """ From f7bff49fbfebc2ca345cc81a43258310457fbe3d Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Tue, 17 Dec 2019 08:00:13 -0800 Subject: [PATCH 2/2] Do not modify the data attached to nodes in the tree during frequency estimation Fixes another bug with frequency estimation caused by a mismatch in the casing of attributes attached to the nodes in the tree, weight attributes in the weights JSON, and what the estimation code originally expected. --- augur/frequency_estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/augur/frequency_estimators.py b/augur/frequency_estimators.py index 54eaede65..8237497e5 100644 --- a/augur/frequency_estimators.py +++ b/augur/frequency_estimators.py @@ -1166,7 +1166,7 @@ def estimate(self, tree): # Find tips with the current weight attribute. 
tips = [(tip.name, tip.attr["num_date"]) for tip in tree.get_terminals() - if tip.attr[self.weights_attribute].lower() == weight_key and self.tip_passes_filters(tip)] + if tip.attr[self.weights_attribute] == weight_key and self.tip_passes_filters(tip)] frequencies.update(self.estimate_tip_frequencies_to_proportion(tips, proportion)) else: tips = [(tip.name, tip.attr["num_date"])