Skip to content

Commit

Permalink
update tests for new value returned from get_groups_for_subsampling
Browse files: browse the repository at this point in the history
  • Loading branch information
victorlin committed Dec 11, 2021
1 parent 9ac13ea commit 6296d8e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
6 changes: 6 additions & 0 deletions augur/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,12 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
>>> skipped_strains
[{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}]
Distinct groups are returned as a set.
>>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
>>> groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["region"])
>>> groups
{('Africa',), ('Europe',)}
"""
metadata = metadata.loc[strains]
group_values = set()
Expand Down
18 changes: 9 additions & 9 deletions tests/test_filter_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestFilterGroupBy:
def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
strains = ['SEQ_1', 'SEQ_3', 'SEQ_5']
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
assert group_by_strain == {
'SEQ_1': ('_dummy',),
'SEQ_3': ('_dummy',),
Expand All @@ -29,7 +29,7 @@ def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame):
def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
strains = metadata.index.tolist()
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
assert group_by_strain == {
'SEQ_1': ('_dummy',),
'SEQ_2': ('_dummy',),
Expand All @@ -51,7 +51,7 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys)
groups = ['country', 'year', 'month', 'invalid']
metadata = valid_metadata.copy()
strains = metadata.index.tolist()
groups, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups)
_, group_by_strain, _ = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1), 'unknown'),
'SEQ_2': ('A', 2020, (2020, 2), 'unknown'),
Expand All @@ -67,7 +67,7 @@ def test_filter_groupby_skip_ambiguous_year(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = "XXXX-02-01"
strains = metadata.index.tolist()
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
Expand All @@ -81,7 +81,7 @@ def test_filter_groupby_skip_missing_date(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = None
strains = metadata.index.tolist()
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
Expand All @@ -95,7 +95,7 @@ def test_filter_groupby_skip_ambiguous_month(self, valid_metadata: pd.DataFrame)
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = "2020-XX-01"
strains = metadata.index.tolist()
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
Expand All @@ -109,7 +109,7 @@ def test_filter_groupby_skip_missing_month(self, valid_metadata: pd.DataFrame):
metadata = valid_metadata.copy()
metadata.at["SEQ_2", "date"] = "2020"
strains = metadata.index.tolist()
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 2020, (2020, 1)),
'SEQ_3': ('B', 2020, (2020, 3)),
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, ca
metadata = valid_metadata.copy()
metadata = metadata.drop('date', axis='columns')
strains = metadata.index.tolist()
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
'SEQ_1': ('A', 'unknown', 'unknown'),
'SEQ_2': ('A', 'unknown', 'unknown'),
Expand All @@ -166,7 +166,7 @@ def test_filter_groupby_no_strains(self, valid_metadata: pd.DataFrame):
groups = ['country', 'year', 'month']
metadata = valid_metadata.copy()
strains = []
groups, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
_, group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {}
assert skipped_strains == []

Expand Down

0 comments on commit 6296d8e

Please sign in to comment.