Fix probabilistic subsampling for small values
Fixes a regression from the rewrite of augur filter where probabilistic
subsampling for small values of `--subsample-max-sequences` could
randomly select zero strains, intermittently causing our CI tests to
fail. Prior to the rewrite of augur filter and the introduction of
priority queues, we fixed this issue by repeatedly recalculating the
number of sequences per group until the values summed to an integer
greater than zero. However, the way I implemented random queue sizes
inside the `PriorityQueue` class in the rewrite prevented me from using
a similar "multiple attempts" approach.

This commit redesigns the way we create priority queues. In the case
where we already know the number of sequences per group in the first
pass, we create an appropriately-sized priority queue for each group as
we encounter it. There is no possibility that the sum of these queue
sizes could be zero.

In the case where we need to calculate the number of sequences per group
from the first pass, we already know all possible groups and can create
their priority queues in bulk. The new `create_queues_by_group` function
allows us to create fixed-size or randomly sized queues and to make
multiple attempts when queue sizes sum to zero. As a result, the
`PriorityQueue` class is much simpler (it requires no logic about random
max sizes) and we can actually test the fixed and random behaviors more
carefully with doctests for `create_queues_by_group`.
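
As a sanity check on the retry bound (illustrative only, not augur
code): attempts are independent, so the chance that every one of
`max_attempts` tries produces zero-size queues is the single-attempt
probability raised to that power, which is negligible even for small
means.

    import numpy as np

    # Hypothetical single-attempt failure probability from the example above.
    p_zero_single_attempt = np.exp(-0.1 * 5)   # ~0.61
    max_attempts = 100
    print(p_zero_single_attempt ** max_attempts)  # ~2e-22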
huddlej committed Aug 13, 2021
1 parent ae80716 commit 5db04fc
Showing 1 changed file with 63 additions and 33 deletions.
augur/filter.py
@@ -24,7 +24,6 @@
 from .utils import is_vcf as filename_is_vcf, read_vcf, read_strains, get_numerical_dates, run_shell_command, shquote, is_date_ambiguous
 
 comment_char = '#'
-MAX_NUMBER_OF_PROBABILISTIC_SAMPLING_ATTEMPTS = 10
 
 SEQUENCE_ONLY_FILTERS = (
     "min_length",
@@ -936,11 +935,6 @@ class PriorityQueue:
     """A priority queue implementation that automatically replaces lower priority
     items in the heap with incoming higher priority items.
 
-    This implementation also allows the maximum size to be a fractional value
-    less than 1 in which case the heap size is sampled randomly from a Poisson
-    distribution with the given maximum size as the mean. This randomly sized
-    heap enables probabilistic subsampling.
-
     Add a single record to a heap with a maximum of 2 records.
 
     >>> queue = PriorityQueue(max_size=2)
@@ -974,28 +968,12 @@ class PriorityQueue:
     >>> list(queue.get_items())
     [{'strain': 'strain4'}, {'strain': 'strain3'}]
 
-    Assign a fractional maximum size such that the corresponding queue limit is
-    sampled randomly from a Poisson distribution. For small values, we should
-    get a max size that is no more than 10 (this is an arbitrarily high number
-    above what we see for Poisson samples drawn with a mean of 0.1).
-
-    >>> queue = PriorityQueue(max_size=0.1)
-    >>> queue.max_size in set(range(10))
-    True
-
     """
     def __init__(self, max_size):
-        """Create a fixed size heap (priority queue) that allows the maximum size to be
-        calculated probabilistically from a Poisson process.
+        """Create a fixed size heap (priority queue)
         """
-        # Fractional heap sizes indicate probabilistic sampling.
-        if max_size < 1.0:
-            random_generator = np.random.default_rng()
-            self.max_size = random_generator.poisson(max_size)
-        else:
-            self.max_size = max_size
-
+        self.max_size = max_size
         self.heap = []
         self.counter = itertools.count()

@@ -1030,10 +1008,53 @@ def get_items(self):
             yield item
 
 
-def priority_queue_factory(max_size):
-    """Return a callable for a priority queue with the given arguments.
+def create_queues_by_group(groups, max_size, max_attempts=100):
+    """Create a dictionary of priority queues per group for the given maximum size.
+
+    When the maximum size is fractional, probabilistically sample the maximum
+    size from a Poisson distribution. Make up to the given maximum number of
+    attempts to create queues whose maximum sizes sum to a value greater than
+    zero.
+
+    Create queues for two groups with a fixed maximum size.
+
+    >>> groups = ("2015", "2016")
+    >>> queues = create_queues_by_group(groups, 2)
+    >>> sum(queue.max_size for queue in queues.values())
+    4
+
+    Create queues for two groups with a fractional maximum size. Their total max
+    size should still be an integer value greater than zero.
+
+    >>> queues = create_queues_by_group(groups, 0.1)
+    >>> int(sum(queue.max_size for queue in queues.values())) > 0
+    True
+
     """
-    return lambda: PriorityQueue(max_size=max_size)
+    queues_by_group = {}
+    total_max_size = 0
+    attempts = 0
+
+    if max_size < 1.0:
+        random_generator = np.random.default_rng()
+
+    # For small fractional maximum sizes, it is possible to randomly select
+    # maximum queue sizes that all equal zero. When this happens, filtering
+    # fails unexpectedly, so we make multiple attempts to create queues where
+    # at least one queue has a maximum size greater than zero.
+    while total_max_size == 0 and attempts < max_attempts:
+        for group in groups:
+            if max_size < 1.0:
+                queue_max_size = random_generator.poisson(max_size)
+            else:
+                queue_max_size = max_size
+
+            queues_by_group[group] = PriorityQueue(queue_max_size)
+
+        total_max_size = sum(queue.max_size for queue in queues_by_group.values())
+        attempts += 1
+
+    return queues_by_group
 
 
 def register_arguments(parser):
Expand Down Expand Up @@ -1347,11 +1368,17 @@ def run(args):
# Track the highest priority records, when we already
# know the number of sequences allowed per group.
if queues_by_group is None:
queues_by_group = defaultdict(priority_queue_factory(
max_size=sequences_per_group,
))
queues_by_group = {}

for strain, group in group_by_strain.items():
# During this first pass, we do not know all possible
# groups will be, so we need to build each group's queue
# as we first encounter the group.
if group not in queues_by_group:
queues_by_group[group] = PriorityQueue(
max_size=sequences_per_group,
)

queues_by_group[group].add(
metadata.loc[strain],
priorities[strain],
@@ -1411,9 +1438,12 @@ def run(args):
         sys.exit(1)
 
     if queues_by_group is None:
-        queues_by_group = defaultdict(priority_queue_factory(
-            max_size=sequences_per_group,
-        ))
+        # We know all of the possible groups now from the first pass through
+        # the metadata, so we can create queues for all groups at once.
+        queues_by_group = create_queues_by_group(
+            records_per_group.keys(),
+            sequences_per_group,
+        )
 
     # Make a second pass through the metadata, only considering records that
     # have passed filters.
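
As a quick interactive check of the new bulk-creation path (group names
here are hypothetical; assumes a checkout that includes this commit):

    from augur.filter import create_queues_by_group

    queues_by_group = create_queues_by_group(("2015", "2016", "2017"), 0.3)
    total_max_size = sum(queue.max_size for queue in queues_by_group.values())
    assert total_max_size > 0  # virtually guaranteed by the retry loop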
