diff --git a/augur/frequencies.py b/augur/frequencies.py index 57db3052e..a1dd3b572 100644 --- a/augur/frequencies.py +++ b/augur/frequencies.py @@ -8,7 +8,7 @@ from Bio.Align import MultipleSeqAlignment from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies -from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies +from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies, TreeKdeFrequenciesError from .utils import read_metadata, read_node_data, write_json, get_numerical_dates @@ -162,7 +162,12 @@ def run(args): include_internal_nodes=args.include_internal_nodes, censored=args.censored ) - frequencies = kde_frequencies.estimate(tree) + + try: + frequencies = kde_frequencies.estimate(tree) + except TreeKdeFrequenciesError as e: + print("ERROR: %s" % str(e), file=sys.stderr) + return 1 # Export frequencies in auspice-format by strain name. frequency_dict = {"pivots": list(kde_frequencies.pivots)} diff --git a/augur/frequency_estimators.py b/augur/frequency_estimators.py index 6bc3014b7..54eaede65 100644 --- a/augur/frequency_estimators.py +++ b/augur/frequency_estimators.py @@ -12,6 +12,12 @@ log_thres = 10.0 +class TreeKdeFrequenciesError(Exception): + """Represents an error estimating KDE frequencies for a tree. + """ + pass + + def get_pivots(observations, pivot_interval, start_date=None, end_date=None): """Calculate pivots for a given list of floating point observation dates and interval between pivots. @@ -1143,6 +1149,13 @@ def estimate(self, tree): for key, value in self.weights.items(): self.weights[key] = value / weight_total + # Confirm that one or more weights are represented by tips in the + # tree. If no weights remain, raise an exception because + # this likely represents a data error (either in the tree + # annotations or the weight definitions). 
+ if len(self.weights) == 0: + raise TreeKdeFrequenciesError("None of the provided weight attributes were represented by tips in the given tree. Doublecheck weight attribute definitions and their representations in the tree.") + # Estimate frequencies for all tips within each weight attribute # group. weight_keys, weight_values = zip(*sorted(self.weights.items())) diff --git a/tests/python3/test_frequencies.py b/tests/python3/test_frequencies.py index c2f653cff..6f483074e 100644 --- a/tests/python3/test_frequencies.py +++ b/tests/python3/test_frequencies.py @@ -12,7 +12,7 @@ # we assume (and assert) that this script is running from the tests/ directory sys.path.append(str(Path(__file__).parent.parent.parent)) -from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies +from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies, TreeKdeFrequenciesError from augur.utils import json_to_tree # Define regions to use for testing weighted frequencies. @@ -173,6 +173,19 @@ def test_weighted_estimate_with_unrepresented_weights(self, tree): # Frequencies should sum to 1 at all pivots. assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots)) + # Estimate weighted frequencies such that all weighted attributes are + # missing. This should raise an exception because none of the tips will + # match any of the weights and the weighting of frequencies will be + # impossible. + weights = {"fake_region_1": 1.0, "fake_region_2": 2.0} + kde_frequencies = TreeKdeFrequencies( + weights=weights, + weights_attribute="region" + ) + + with pytest.raises(TreeKdeFrequenciesError): + frequencies = kde_frequencies.estimate(tree) + def test_only_tip_estimates(self, tree): """Test frequency estimation for only tips in a given tree. """