filter: Rewrite priority queue logic with pandas functions #809

Draft: victorlin wants to merge 9 commits into master from victorlin/filter/priority-speedup

Conversation

@victorlin (Member) commented Dec 10, 2021

Description of proposed changes

  • remove class PriorityQueue
  • use prioritized_metadata DataFrame in place of queues_per_group (see the sketch after this list)
    • this also removes the need for a second call to get_groups_for_subsampling
  • repurpose create_queues_per_group to create_sizes_per_group
  • other logical refactoring:
    • use global dummy group key and value
      • key is a list: pd.DataFrame.groupby does not accept a tuple as the grouping key, and our --group-by is already stored as a list.
      • value is a tuple: get_groups_for_subsampling returns group values in this form.
    • use records_per_group for _dummy
      • replace conditional logic of records_per_group is not None with group_by
      • add TODO to rewrite logic for records_per_group with pandas functions
    • move date-expanding logic from get_groups_for_subsampling to a separate function: 1b106d5
    • store unique group values globally: 4e3e155, 6296d8e
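
For illustration, here is a minimal sketch of the core idea, with simplified names (the actual code in augur/filter.py also handles chunked metadata reading, the dummy group, and probabilistic sampling):

import pandas as pd

def keep_top_priority_records(metadata, group_by, priorities, sequences_per_group):
    """Keep the highest-priority records per group using vectorized pandas
    operations instead of per-group priority queues."""
    metadata = metadata.copy()
    metadata["priority"] = metadata.index.map(priorities)
    # Rank records within each group by descending priority (0 = highest).
    metadata["rank"] = (
        metadata.sort_values("priority", ascending=False)
                .groupby(group_by, sort=False)
                .cumcount()
    )
    return metadata[metadata["rank"] < sequences_per_group].drop(columns="rank")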

Related issue(s)

Testing

  • Added functional tests for groupby

codecov bot commented Dec 10, 2021

Codecov Report

Merging #809 (eea96fb) into master (a3a79ca) will decrease coverage by 0.20%.
The diff coverage is 48.00%.

@@            Coverage Diff             @@
##           master     #809      +/-   ##
==========================================
- Coverage   33.81%   33.60%   -0.21%     
==========================================
  Files          41       41              
  Lines        5900     5933      +33     
  Branches     1465     1480      +15     
==========================================
- Hits         1995     1994       -1     
- Misses       3822     3854      +32     
- Partials       83       85       +2     
Impacted Files Coverage Δ
augur/filter.py 65.88% <48.00%> (-2.29%) ⬇️
augur/refine.py 5.02% <0.00%> (-0.44%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@victorlin changed the title from "Rewrite PriorityQueue logic with pandas functions" to "filter: Rewrite priority queue logic with pandas functions" on Dec 10, 2021
@victorlin force-pushed the victorlin/filter/priority-speedup branch from aba5515 to 6296d8e on December 11, 2021 00:19
@victorlin force-pushed the victorlin/filter/priority-speedup branch from 6296d8e to 9cf2264 on December 11, 2021 01:08
@victorlin force-pushed the victorlin/filter/priority-speedup branch from b5101bc to e01d302 on December 11, 2021 01:14
This test currently fails with a pandas-specific index error.
Instead of calculating a new (year, month) tuple when users group by
month, add a "year" key to the list of group fields. This fixes a pandas
indexing bug where calling `nlargest` on a SeriesGroupBy object that has
a year and month tuple key for month causes pandas to think the single
month key is a MultiIndex that should be a list. Although this commit is
motivated by the need to fix this pandas issue, this year/month
disambiguation is also simpler and more idiomatic pandas, and it wouldn't
have been possible in the original augur filter implementation (before we
switched to pandas for metadata parsing).
Simplifies unit tests and doctests by expecting a single value for each
month instead of a tuple.
@huddlej (Contributor) left a comment

Thank you for this huge effort, @victorlin! This PR represents a lot of work, both in managing the complexity of the filter module and in coming up with a solution that uses vectorized operations with pandas.

Before I get to the specific implementation details, here are some performance results for the code in this PR vs. Augur 13.1.0. First, I got a random sample of 10 strains from Nebraska with the full GenBank (“open”) metadata. This command took slightly longer to run with this PR's implementation.

# Get 10 random sequences with "dummy" group.
# Ran in 3.5 min instead of 2 min pre-PR.
augur filter \
  --metadata metadata.tsv.gz \
  --query "division == 'Nebraska'" \
  --subsample-max-sequences 10 \
  --output-metadata nebraska_random_metadata.tsv \
  --output-strains nebraska_random_strains.txt

Next, I tried getting 10 random sequences per month. This command ran for 3 minutes before crashing.

# Get 10 random sequences by month.
# Ran in 3 min before crashing.
augur filter \
  --metadata metadata.tsv.gz \
  --query "division == 'Nebraska'" \
  --subsample-max-sequences 10 \
  --group-by month \
  --output-metadata nebraska_random_metadata.tsv \
  --output-strains nebraska_random_strains.txt

The exception from the command above was:

Traceback (most recent call last):
  File "/Users/jlhudd/miniconda3/envs/nextstrain/bin/augur", line 33, in <module>
    sys.exit(load_entry_point('nextstrain-augur', 'console_scripts', 'augur')())
  File "/Users/jlhudd/projects/nextstrain/augur/augur/__main__.py", line 10, in main
    return augur.run( argv[1:] )
  File "/Users/jlhudd/projects/nextstrain/augur/augur/__init__.py", line 75, in run
    return args.__command__.run(args)
  File "/Users/jlhudd/projects/nextstrain/augur/augur/filter.py", line 1479, in run
    metadata_copy.groupby(group_by, sort=False)['priority']
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/groupby/groupby.py", line 948, in wrapper
    return self._python_apply_general(curried, self._obj_with_exclusions)
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/groupby/groupby.py", line 1311, in _python_apply_general
    return self._wrap_applied_output(
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/groupby/generic.py", line 472, in _wrap_applied_output
    return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/groupby/groupby.py", line 1044, in _concat_objects
    result = concat(
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/util/_decorators.py", line 311, in wrapper
    return func(*args, **kwargs)
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/reshape/concat.py", line 294, in concat
    op = _Concatenator(
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/reshape/concat.py", line 371, in __init__
    keys = Index(clean_keys, name=name)
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/indexes/base.py", line 489, in __new__
    return MultiIndex.from_tuples(
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/indexes/multi.py", line 202, in new_meth
    return meth(self_or_cls, *args, **kwargs)
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/indexes/multi.py", line 560, in from_tuples
    return cls.from_arrays(arrays, sortorder=sortorder, names=names)
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/indexes/multi.py", line 487, in from_arrays
    return cls(
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/indexes/multi.py", line 331, in __new__
    result._set_names(names)
  File "/Users/jlhudd/miniconda3/envs/nextstrain/lib/python3.8/site-packages/pandas/core/indexes/multi.py", line 1423, in _set_names
    raise ValueError("Names should be list-like for a MultiIndex")
ValueError: Names should be list-like for a MultiIndex

The error specifically occurs with this call:

metadata_copy.groupby(group_by, sort=False)["priority"].nlargest(int_group_size, keep='last')

where group_by is month and int_group_size is 4. It turns out that pandas (1.3.4) interprets the (year, month) tuple in the month column as a MultiIndex and expects the values to be in a list. Commit 897d00e adds a functional test that fails with this specific error.

This feels like a pandas edge case, but it also highlights a strange implementation choice in our code. In the original augur filter, we grouped records by storing them in dictionaries keyed by group. At the time, it seemed to make sense to store month values as year/month tuples, to avoid meaninglessly sampling across years when grouping by month alone. This choice no longer makes much sense now that we use pandas DataFrames where year and month have their own columns and grouping on them is simple. If we want to implicitly group by year and month when the user asks for month, I think that logic should live closer to the highest level of the filter module rather than in the filter's internal functions. Commit 966da1d makes this change and commit eea96fb updates unit tests and doctests to reflect it.
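
A rough sketch of what that top-level adjustment could look like (the function name here is illustrative, not necessarily what commit 966da1d actually adds):

def expand_group_by_for_month(group_by):
    """Implicitly group by year whenever the user groups by month, so that
    e.g. March 2020 and March 2021 do not collapse into a single group."""
    group_by = list(group_by)
    if "month" in group_by and "year" not in group_by:
        group_by.insert(group_by.index("month"), "year")
    return group_by

expand_group_by_for_month(["country", "month"])  # -> ["country", "year", "month"]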

With this fix in place, I re-ran the filter for 10 random Nebraska sequences grouped by month. It ran in 4 minutes, the same amount of time required with Augur 13.1.0.

Next, I ran Trevor’s more complex filter command on the full sanitized GISAID data as shown below. This took 11 minutes with this PR and 17 minutes with Augur 13.1.0.

# Run Trevor's filter command on full sanitized GISAID data.
# This took 24 minutes with Augur 13.0.3.
# It took 17 minutes with Augur 13.1.0.
# It took 11 minutes with the new groupby priorities logic.
time augur filter \
 --metadata sanitized_metadata_gisaid.tsv.xz \
 --include defaults/include.txt \
 --exclude defaults/exclude.txt \
 --min-date 2021-07-28 \
 --exclude-where 'region!=Europe' \
 --group-by country year month \
 --subsample-max-sequences 10 \
 --probabilistic-sampling \
 --output-strains sample-europe_late.txt

For reference, here are some equivalent commands with tsv-utils for the two Nebraska examples above (it’s less obvious how to completely map Trevor’s filter command to tsv-utils, though it seems possible):

# Get all Nebraska metadata.
# Runs in 4 seconds.
time gzip -c -d metadata.tsv.gz | tsv-filter -H --str-eq division:Nebraska > all_nebraska_metadata.tsv

# Get 10 random metadata records from Nebraska.
# Runs in 3 seconds.
gzip -c -d metadata.tsv.gz | tsv-filter -H --str-eq division:Nebraska | tsv-sample -H --n 10 > random_nebraska_metadata.tsv

Overall, there seems to be a slight performance loss to the new implementation for smaller queries and a performance increase for larger queries. This outcome makes sense to me for a couple of reasons:

  1. We now make two passes through the metadata for all group-by queries whereas the 13.1.0 implementation only makes a second pass when the number of sequences per group is unknown (--subsample-max-sequences and --group-by are defined). I would expect the code in this PR to be slightly slower for the use case where users provide a known number of sequences per group.
  2. We now consider the number of sequences per group as unknown when the user provides --subsample-max-sequences and no --group-by values, whereas the 13.1.0 implementation knows the resulting “dummy” group needs as many sequences per group as the total number of requested sequences. I suspect this is why the first Nebraska filter above is slightly slower now than in 13.1.0.

On the other hand, we are not making two calls to get_groups_for_subsampling now and we use vectorized operations to keep the highest priority records. These changes clearly speed up the use case where we needed to make a second pass through the metadata anyway.

To test the hypothesis that --sequences-per-group should be slower with this PR’s implementation (due to the required second pass compared to a single pass in Augur 13.1.0), I ran the following command with this PR’s code and Augur 13.1.0:

# Get 1 random sequence per month (and year).
# This ran in 2 minutes with Augur 13.1.0.
# This ran in 4 minutes with this PR.
augur filter \
  --metadata metadata.tsv.gz \
  --query "division == 'Nebraska'" \
  --sequences-per-group 1 \
  --group-by month \
  --output-metadata nebraska_random_metadata.tsv \
  --output-strains nebraska_random_strains.txt

The original code runs in 2 minutes while the new code runs in 4 minutes. I would expect these values to diverge even more as the metadata input size grows (the cost of looping through fixed-size priority queues will remain the same in Augur 13.1.0, but the cost of the second pass will increase in this PR). To test this second hypothesis, I repeated the command above with the sanitized GISAID database. This database has 5,517,277 records compared to 2,493,771 records in the open database.

# Get 1 random sequence per month (and year) from GISAID.
# This ran in 4.5 minutes with Augur 13.1.0.
# This ran in 9 minutes with this PR.
augur filter \
  --metadata sanitized_metadata_gisaid.tsv.xz \
  --query "division == 'Nebraska'" \
  --sequences-per-group 1 \
  --group-by month \
  --output-metadata nebraska_random_metadata.tsv \
  --output-strains nebraska_random_strains.txt

The command above ran in 9 minutes with this PR and 4.5 minutes with Augur 13.1.0. This result is consistent with some fixed cost to loop through the metadata twice. I would like to maintain the high performance of these alternate use cases while also speeding up the use case Trevor originally described.

We can think of the new code you’ve written to build the prioritized metadata data frame as analogous to the priority queues in 13.1.0. The priority queues exist to capture the N highest priority records in memory. When we know the number of sequences per group during the first pass through the metadata (i.e., either --sequences-per-group or just --subsample-max-sequences with a dummy group), we can add to the priority queues in that first pass and then emit the records from the queue without a second pass. We only need a second pass through the metadata when we need to calculate the number of sequences per group from information gathered in the first pass.

If we refactor and abstract the logic for building the prioritized data frame into its own function or class, we could use the same approach to build a prioritized data frame in the first pass or an optional second pass. This function or class would abstract some of the pandas-related complexity in that second pass and allow us to speed up all of our use cases. This approach would also be consistent with the original goal of this PR: to use vectorized data frame operations to track the highest priority records instead of the slower priority queues.
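
As a very rough sketch of the kind of interface I have in mind (the names are hypothetical, and this assumes each chunk already carries its priority column):

import pandas as pd

def update_prioritized_metadata(prioritized_metadata, metadata_chunk, group_by, sequences_per_group):
    """Fold a new metadata chunk into the running prioritized data frame,
    keeping only the top sequences_per_group records per group."""
    combined = pd.concat([prioritized_metadata, metadata_chunk])
    combined["rank"] = (
        combined.sort_values("priority", ascending=False)
                .groupby(group_by, sort=False)
                .cumcount()
    )
    return combined[combined["rank"] < sequences_per_group].drop(columns="rank")

This could be called once per chunk during the first pass (when the sequences per group are known up front) or during the optional second pass (when they have to be calculated), so both use cases share the same vectorized prioritization code.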

I left a comment to this effect inline below asking about how you might define a function for the code inside the for loop for the second pass. I’m happy to chat about this more over video/audio, too.

In addition to the major points above, I left several other smaller comments inline below. Most are questions about pandas syntax.

Comment on lines -1299 to +1267
if group_by and args.subsample_max_sequences:
# In this case, we need two passes through the metadata with the first
# pass used to count the number of records per group.
if args.subsample_max_sequences:
records_per_group = defaultdict(int)
elif not group_by and args.subsample_max_sequences:
group_by = ("_dummy",)
sequences_per_group = args.subsample_max_sequences
if not group_by:
group_by = dummy_group
sequences_per_group = args.subsample_max_sequences

Although this reduces the number of branch logic lines, I don't think this reorg makes the logic clearer. It causes records_per_group to be instantiated in cases when it isn't needed and drops a clarifying comment that explains the existence of that variable.

@@ -1409,32 +1373,13 @@ def run(args):
if args.output_log:
output_log_writer.writerow(skipped_strain)

if args.subsample_max_sequences and records_per_group is not None:
if args.subsample_max_sequences and group_by:

This may become clearer later but related to my comment above, we don't need to count records per group in the case that the "dummy" group is used; we already know the number of sequences per group equals the max number requested because there is only one group.

Would you say that the extra cost of calculating sequences per group for the dummy group is worth the benefit of simpler code downstream (handling one less branch in the logic)?

@@ -1470,11 +1415,12 @@ def run(args):
for strain in strains_to_write:
output_strains.write(f"{strain}\n")

probabilistic_used = False

This line feels orphaned from its related logic. If you move this below the comment, the comment will still make sense, but this line will have closer context to justify itself.

seq_keep_ordered = [seq for seq in metadata.index if seq in seq_keep]
metadata_copy = metadata.loc[seq_keep_ordered].copy()
# add columns for priority and expanded date
metadata_copy['priority'] = metadata_copy.index.to_series().apply(lambda x: priorities[x])

Is there a reason we can't use index.map(priorities) here instead of a lambda?
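
For example (a sketch, assuming priorities is a dict-like keyed by strain name):

metadata_copy['priority'] = metadata_copy.index.map(priorities)  # no per-row Python lambda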

Comment on lines +1463 to +1464
metadata_copy = (pd.concat([metadata_copy, metadata_with_dates[['year','month','day']]], axis=1)
.reindex(metadata_copy.index)) # not needed after pandas>=1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html#index-column-name-preservation-when-aggregating

Is there a reason not to use merge here instead of concat? I always worry about concatenating data frames with an assumed correct ordering of records. A merge (or join on index) is more reassuring because you can provide the explicit keys you expect in both inputs and also enforce merge outcomes with the validate argument.

Maybe this issue relates to the intermediate variable TODO in the line above. If so, we should address that TODO in this PR just to keep this code as clean as possible.
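
For example, something along these lines (a sketch, not a tested drop-in; validate='one_to_one' raises if either index contains duplicates):

metadata_copy = metadata_copy.merge(
    metadata_with_dates[['year', 'month', 'day']],
    how='left',
    left_index=True,
    right_index=True,
    validate='one_to_one',
)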

Comment on lines +1501 to +1504
prioritized_metadata['group'] = list(zip(*[prioritized_metadata[col] for col in group_by]))
prioritized_metadata['group_size'] = prioritized_metadata['group'].map(sizes_per_group)
prioritized_metadata['group_cumcount'] = prioritized_metadata.sort_values('priority', ascending=False).groupby(group_by + ['group_size']).cumcount()
prioritized_metadata = prioritized_metadata[prioritized_metadata['group_cumcount'] < prioritized_metadata['group_size']]

If I understand this code correctly, this logic will prevent group-by output from exceeding the requested maximum number of sequences when we perform probabilistic sampling. Is that right?

Could you add some introductory comments to this block to explain its general purpose? I assume that the other part of this logic is enforcing the calculated number of sequences per group when all of the preceding code had been allowing the maximum of all sequences per group.
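
For what it's worth, here is how I read the block, written as the kind of comments that could be added (my interpretation, not the author's):

# Build a per-record tuple key that matches the keys of sizes_per_group.
prioritized_metadata['group'] = list(zip(*[prioritized_metadata[col] for col in group_by]))
# Look up the allowed number of sequences for each record's group.
prioritized_metadata['group_size'] = prioritized_metadata['group'].map(sizes_per_group)
# Rank records within each group by descending priority (0 = highest priority).
prioritized_metadata['group_cumcount'] = prioritized_metadata.sort_values('priority', ascending=False).groupby(group_by + ['group_size']).cumcount()
# Keep only the top group_size records in each group, which also caps the total
# output when probabilistic sampling assigns the per-group sizes.
prioritized_metadata = prioritized_metadata[prioritized_metadata['group_cumcount'] < prioritized_metadata['group_size']]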

seq_keep,
metadata,
group_by,
if not seq_keep:

Everything that follows in this metadata loop generally makes sense (although see other inline comments where it doesn't make sense :)). A couple of major points occurred to me, reading through this:

  1. You might consider what the interface for this code would look like if you needed to move it into a function that gets called per loop through the metadata. For example, there is a lot of preprocessing that needs to happen to the metadata copy before the (most important) prioritization logic can run. How much of that setup code could be abstracted to make the purpose of this inner loop clearer?

  2. This section needs some brief introductory comments to setup the general approach to handling grouping and priorities. Future readers of this code may not be as proficient in pandas as you are or may miss the point that we need to limit memory while always prioritizing the top N records per group. A quick intro could help avoid a lot of confusion.

On a smaller aesthetic note, some line returns between logical blocks (usually preceding comments) would also improve readability. Pandas code, in particular, encourages run-on blocks of data frame modifications, but it's ok to let these blocks breathe a bit.

@@ -0,0 +1,123 @@
Integration tests for grouping features in augur filter.

Thank you for adding these tests! Several of these new tests potentially overlap with existing tests in filter.t. Would you mind making a quick pass through filter.t to drop any redundant tests there that exist in this file?

Comment on lines +28 to +29
dummy_group = ['_dummy',]
dummy_group_value = ('_dummy',)

Minor point: can you capitalize these variables throughout? It's a small convention that helps call out global variables elsewhere in the code.

It also looks like we don't use the comment_char any more, so we could delete that line.
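
For example:

DUMMY_GROUP = ['_dummy',]
DUMMY_GROUP_VALUE = ('_dummy',)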

if prioritized_metadata.empty:
prioritized_metadata = chunk_prioritized_metadata
else:
prioritized_metadata = prioritized_metadata.append(chunk_prioritized_metadata)

Have you considered using concat with a list of data frames here instead of append? I tend to follow the list and concat pattern based on the recommendation in the pandas docs for append, but I don't know how much that change actually affects our performance here.
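
For example (a sketch; chunks here stands in for the per-chunk results of the metadata loop):

prioritized_chunks = []
for chunk_prioritized_metadata in chunks:
    prioritized_chunks.append(chunk_prioritized_metadata)
prioritized_metadata = pd.concat(prioritized_chunks)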

@victorlin (Member, Author) commented Dec 30, 2021

Converting to draft due to outstanding issues and discussion of an alternative option.

@victorlin marked this pull request as draft on December 30, 2021 20:05