diff --git a/.pylintrc b/.pylintrc index 8e8e70392..023aa51ed 100644 --- a/.pylintrc +++ b/.pylintrc @@ -142,7 +142,8 @@ disable=print-statement, multiple-imports, no-else-return, unscriptable-object, - relative-beyond-top-level + relative-beyond-top-level, + no-member # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/augur/filter.py b/augur/filter.py index d263540bc..02c316474 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -8,6 +8,7 @@ import random, os, re import pandas as pd import numpy as np +import operator import sys import datetime from tempfile import NamedTemporaryFile @@ -72,25 +73,328 @@ def read_priority_scores(fname): print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr) raise e -def filter_by_query(strains, metadata, query): - """Filter a set of strains using Pandas DataFrame querying against the metadata file. +# Define metadata filters. + +def filter_by_exclude_all(metadata): + """Exclude all strains regardless of the given metadata content. + + This is a placeholder function that can be called as part of a generalized + loop through all possible functions. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + + Returns + ------- + set[str]: + Empty set of strains + + >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) + >>> filter_by_exclude_all(metadata) + set() + """ + return set() + + +def filter_by_exclude(metadata, excluded_strains): + """Exclude the given set of strains from the given metadata. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + excluded_strains : set[str] + Set of strain names to exclude from the given metadata + + Returns + ------- + set[str]: + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) + >>> filter_by_exclude(metadata, {"strain1"}) + {'strain2'} + """ + return set(metadata.index.values) - excluded_strains + + +def parse_filter_query(query): + """Parse an augur filter-style query and return the corresponding column, + operator, and value for the query. + + Parameters + ---------- + query : str + augur filter-style query following the pattern of `"property=value"` or `"property!=value"` + + Returns + ------- + str : + Name of column to query + callable : + Operator function to test equality or non-equality of values + str : + Value of column to query + + >>> parse_filter_query("property=value") + ('property', , 'value') + >>> parse_filter_query("property!=value") + ('property', , 'value') + + """ + column, value = re.split(r'!?=', query) + op = operator.eq + if "!=" in query: + op = operator.ne + + return column, op, value + + +def filter_by_exclude_where(metadata, exclude_where): + """Exclude all strains from the given metadata that match the given exclusion query. + + Unlike pandas query syntax, exclusion queries should follow the pattern of + `"property=value"` or `"property!=value"`. Additionally, this filter treats + all values like lowercase strings, so we convert all values to strings first + and then lowercase them before testing the given query. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + exclude_where : str + Filter query used to exclude strains + + Returns + ------- + set[str]: + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) + >>> filter_by_exclude_where(metadata, "region!=Europe") + {'strain2'} + >>> filter_by_exclude_where(metadata, "region=Europe") + {'strain1'} + >>> filter_by_exclude_where(metadata, "region=europe") + {'strain1'} + + """ + column, op, value = parse_filter_query(exclude_where) + excluded = op(metadata[column].astype(str).str.lower(), value.lower()) + + # Negate the boolean index of excluded strains to get the index of strains + # that passed the filter. + included = ~excluded + return set(metadata[included].index.values) + + +def filter_by_query(metadata, query): + """Filter metadata in the given pandas DataFrame with a query string and return + the strain names that pass the filter. Parameters ---------- - strains : set[str] - Set of strain names to filter metadata : pandas.DataFrame - Metadata associated with the given strains + Metadata indexed by strain name query : str Query string for the dataframe. Returns ------- set[str]: - Set of strains that match the given query + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) + >>> filter_by_query(metadata, "region == 'Africa'") + {'strain1'} + >>> filter_by_query(metadata, "region == 'North America'") + set() + + """ + return set(metadata.query(query).index.values) + + +def filter_by_ambiguous_date(metadata, date_column="date", ambiguity="any"): + """Filter metadata in the given pandas DataFrame where values in the given date + column have a given level of ambiguity. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + date_column : str + Column in the dataframe with dates. + ambiguity : str + Level of date ambiguity to filter metadata by + + Returns + ------- + set[str]: + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-XX"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) + >>> filter_by_ambiguous_date(metadata) + {'strain2'} + >>> sorted(filter_by_ambiguous_date(metadata, ambiguity="month")) + ['strain1', 'strain2'] + + """ + date_is_ambiguous = metadata[date_column].apply( + lambda date: is_date_ambiguous(date, ambiguity) + ) + return set(metadata[~date_is_ambiguous].index.values) + + +def filter_by_date(metadata, date_column="date", min_date=None, max_date=None): + """Filter metadata by minimum or maximum date. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + date_column : str + Column in the dataframe with dates. + min_date : float + Minimum date + max_date : float + Maximum date + + Returns + ------- + set[str]: + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) + >>> filter_by_date(metadata, min_date=numeric_date("2020-01-02")) + {'strain2'} + >>> filter_by_date(metadata, max_date=numeric_date("2020-01-01")) + {'strain1'} + >>> filter_by_date(metadata, min_date=numeric_date("2020-01-03"), max_date=numeric_date("2020-01-10")) + set() + >>> sorted(filter_by_date(metadata, min_date=numeric_date("2019-12-30"), max_date=numeric_date("2020-01-10"))) + ['strain1', 'strain2'] + >>> sorted(filter_by_date(metadata)) + ['strain1', 'strain2'] + + """ + strains = set(metadata.index.values) + if not min_date and not max_date: + return strains + + dates = get_numerical_dates(metadata, fmt="%Y-%m-%d") + filtered = {strain for strain in strains if dates[strain] is not None} + + if min_date: + filtered = {s for s in filtered if (np.isscalar(dates[s]) or all(dates[s])) and np.max(dates[s]) >= min_date} + + if max_date: + filtered = {s for s in filtered if (np.isscalar(dates[s]) or all(dates[s])) and np.min(dates[s]) <= max_date} + + return filtered + + +def filter_by_sequence_length(metadata, sequence_index, min_length=0): + """Filter metadata by sequence length from a given sequence index. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + sequence_index : pandas.DataFrame + Sequence index + min_length : int + Minimum number of standard nucleotide characters (A, C, G, or T) in each sequence + + Returns + ------- + set[str]: + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) + >>> sequence_index = pd.DataFrame([{"strain": "strain1", "ACGT": 28000}, {"strain": "strain2", "ACGT": 26000}]).set_index("strain") + >>> filter_by_sequence_length(metadata, sequence_index, min_length=27000) + {'strain1'} + + It is possible for the sequence index to be missing strains present in the metadata. + + >>> sequence_index = pd.DataFrame([{"strain": "strain3", "ACGT": 28000}, {"strain": "strain2", "ACGT": 26000}]).set_index("strain") + >>> filter_by_sequence_length(metadata, sequence_index, min_length=27000) + set() + + """ + strains = set(metadata.index.values) + filtered_sequence_index = sequence_index.loc[ + sequence_index.index.intersection(strains) + ] + + return set(filtered_sequence_index[filtered_sequence_index["ACGT"] >= min_length].index.values) + + +def filter_by_non_nucleotide(metadata, sequence_index): + """Filter metadata for strains with invalid nucleotide content. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + sequence_index : pandas.DataFrame + Sequence index + + Returns + ------- + set[str]: + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) + >>> sequence_index = pd.DataFrame([{"strain": "strain1", "invalid_nucleotides": 0}, {"strain": "strain2", "invalid_nucleotides": 1}]).set_index("strain") + >>> filter_by_non_nucleotide(metadata, sequence_index) + {'strain1'} + + """ + strains = set(metadata.index.values) + filtered_sequence_index = sequence_index.loc[ + sequence_index.index.intersection(strains) + ] + no_invalid_nucleotides = filtered_sequence_index["invalid_nucleotides"] == 0 + + return set(filtered_sequence_index[no_invalid_nucleotides].index.values) + + +def include_by_query(metadata, include_where): + """Include all strains from the given metadata that match the given query. + + Unlike pandas query syntax, inclusion queries should follow the pattern of + `"property=value"` or `"property!=value"`. Additionally, this filter treats + all values like lowercase strings, so we convert all values to strings first + and then lowercase them before testing the given query. + + Parameters + ---------- + metadata : pandas.DataFrame + Metadata indexed by strain name + include_where : str + Filter query used to include strains + + Returns + ------- + set[str]: + Strains that pass the filter + + >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) + >>> include_by_query(metadata, "region!=Europe") + {'strain1'} + >>> include_by_query(metadata, "region=Europe") + {'strain2'} + >>> include_by_query(metadata, "region=europe") + {'strain2'} + """ - filtered_strains = set(metadata.query(query).index.values) - return strains & filtered_strains + column, op, value = parse_filter_query(include_where) + included = op(metadata[column].astype(str).str.lower(), value.lower()) + return set(metadata[included].index.values) + def register_arguments(parser): input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered") @@ -231,7 +535,8 @@ def run(args): sequence_index = pd.read_csv( sequence_index_path, - sep="\t" + sep="\t", + index_col="strain", ) # Remove temporary index file, if it exists. @@ -240,7 +545,7 @@ def run(args): # Calculate summary statistics needed for filtering. sequence_index["ACGT"] = sequence_index.loc[:, ["A", "C", "G", "T"]].sum(axis=1) - sequence_strains = set(sequence_index["strain"].values) + sequence_strains = set(sequence_index.index.values) else: sequence_strains = None @@ -270,7 +575,7 @@ def run(args): # Exclude all strains by default. if args.exclude_all: num_excluded_by_all = len(available_strains) - seq_keep = set() + seq_keep = filter_by_exclude_all(metadata) # remove strains explicitly excluded by name # read list of strains to exclude from file and prune seq_keep @@ -278,8 +583,9 @@ def run(args): if args.exclude: try: to_exclude = read_strains(*args.exclude) - num_excluded_by_name = len(seq_keep & to_exclude) - seq_keep = seq_keep - to_exclude + filtered = seq_keep & filter_by_exclude(metadata, to_exclude) + num_excluded_by_name = len(seq_keep - filtered) + seq_keep = filtered except FileNotFoundError as e: print("ERROR: Could not open file of excluded strains '%s'" % args.exclude, file=sys.stderr) sys.exit(1) @@ -290,51 +596,42 @@ def run(args): if args.exclude_where: for ex in args.exclude_where: try: - col, val = re.split(r'!?=', ex) + filtered = seq_keep & filter_by_exclude_where(metadata, ex) + num_excluded_by_metadata[ex] = len(seq_keep - filtered) + seq_keep = filtered except (ValueError,TypeError): + # TODO: this validation should happen at the argparse level and + # throw an error instead of trying to continue filtering with an + # invalid filter query. print("invalid --exclude-where clause \"%s\", should be of from property=value or property!=value"%ex) - else: - to_exclude = set() - for seq_name in seq_keep: - if "!=" in ex: # i.e. property!=value requested - if metadata.loc[seq_name].get(col,'unknown').lower() != val.lower(): - to_exclude.add(seq_name) - else: # i.e. property=value requested - if metadata.loc[seq_name].get(col,'unknown').lower() == val.lower(): - to_exclude.add(seq_name) - - num_excluded_by_metadata[ex] = len(seq_keep & to_exclude) - seq_keep = seq_keep - to_exclude # exclude strains by metadata, using Pandas querying num_excluded_by_query = 0 if args.query: - filtered = filter_by_query(seq_keep, metadata, args.query) + filtered = seq_keep & filter_by_query(metadata, args.query) num_excluded_by_query = len(seq_keep - filtered) seq_keep = filtered # filter by ambiguous dates num_excluded_by_ambiguous_date = 0 if args.exclude_ambiguous_dates_by and 'date' in meta_columns: - seq_keep_by_date = set() - for seq_name in seq_keep: - if not is_date_ambiguous(metadata.loc[seq_name]['date'], args.exclude_ambiguous_dates_by): - seq_keep_by_date.add(seq_name) - - num_excluded_by_ambiguous_date = len(seq_keep) - len(seq_keep_by_date) - seq_keep = seq_keep_by_date + filtered = seq_keep & filter_by_ambiguous_date( + metadata, + ambiguity=args.exclude_ambiguous_dates_by + ) + num_excluded_by_ambiguous_date = len(seq_keep - filtered) + seq_keep = filtered # filter by date num_excluded_by_date = 0 if (args.min_date or args.max_date) and 'date' in meta_columns: - dates = get_numerical_dates(metadata, fmt="%Y-%m-%d") - tmp = {s for s in seq_keep if dates[s] is not None} - if args.min_date: - tmp = {s for s in tmp if (np.isscalar(dates[s]) or all(dates[s])) and np.max(dates[s])>=args.min_date} - if args.max_date: - tmp = {s for s in tmp if (np.isscalar(dates[s]) or all(dates[s])) and np.min(dates[s])<=args.max_date} - num_excluded_by_date = len(seq_keep) - len(tmp) - seq_keep = tmp + filtered = seq_keep & filter_by_date( + metadata, + min_date=args.min_date, + max_date=args.max_date, + ) + num_excluded_by_date = len(seq_keep - filtered) + seq_keep = filtered # filter by sequence length num_excluded_by_length = 0 @@ -342,31 +639,20 @@ def run(args): if is_vcf: #doesn't make sense for VCF, ignore. print("WARNING: Cannot use min_length for VCF files. Ignoring...") else: - is_in_seq_keep = sequence_index["strain"].isin(seq_keep) - is_gte_min_length = sequence_index["ACGT"] >= args.min_length - - seq_keep_by_length = set( - sequence_index[ - (is_in_seq_keep) & (is_gte_min_length) - ]["strain"].tolist() + filtered = seq_keep & filter_by_sequence_length( + metadata, + sequence_index, + min_length=args.min_length ) - - num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length) - seq_keep = seq_keep_by_length + num_excluded_by_length = len(seq_keep - filtered) + seq_keep = filtered # exclude sequences with non-nucleotide characters num_excluded_by_nuc = 0 if args.non_nucleotide: - is_in_seq_keep = sequence_index["strain"].isin(seq_keep) - no_invalid_nucleotides = sequence_index["invalid_nucleotides"] == 0 - seq_keep_by_valid_nucleotides = set( - sequence_index[ - (is_in_seq_keep) & (no_invalid_nucleotides) - ]["strain"].tolist() - ) - - num_excluded_by_nuc = len(seq_keep) - len(seq_keep_by_valid_nucleotides) - seq_keep = seq_keep_by_valid_nucleotides + filtered = seq_keep & filter_by_non_nucleotide(metadata, sequence_index) + num_excluded_by_nuc = len(seq_keep - filtered) + seq_keep = filtered # subsampling. This will sort sequences into groups by meta data fields # specified in --group-by and then take at most --sequences-per-group @@ -541,18 +827,14 @@ def run(args): num_included_by_metadata = 0 if args.include_where: to_include = set() - for ex in args.include_where: try: - col, val = ex.split("=") + to_include |= include_by_query(metadata, ex) except (ValueError,TypeError): - print("invalid include clause %s, should be of from property=value"%ex) - continue - - # loop over all sequences and re-add sequences - for seq_name in available_strains: - if metadata.loc[seq_name].get(col)==val: - to_include.add(seq_name) + # TODO: this validation should happen at the argparse level and + # throw an error instead of trying to continue filtering with an + # invalid filter query. + print("invalid --include-where clause \"%s\", should be of from property=value or property!=value"%ex) num_included_by_metadata = len(to_include) seq_keep = seq_keep | to_include diff --git a/tests/functional/filter.t b/tests/functional/filter.t index 7b4c29fc2..35afc3d6f 100644 --- a/tests/functional/filter.t +++ b/tests/functional/filter.t @@ -3,6 +3,19 @@ Integration tests for augur filter. $ pushd "$TESTDIR" > /dev/null $ export AUGUR="../../bin/augur" +Filter with exclude query for two regions that comprise all but one strain. +This filter should leave a single record from Oceania. +Force include one South American record by country to get two total records. + + $ ${AUGUR} filter \ + > --metadata filter/metadata.tsv \ + > --exclude-where "region=South America" "region=North America" \ + > --include-where "country=Ecuador" \ + > --output-strains "$TMP/filtered_strains.txt" > /dev/null + $ wc -l "$TMP/filtered_strains.txt" + \s*2 .* (re) + $ rm -f "$TMP/filtered_strains.txt" + Filter with subsampling, requesting 1 sequence per group (for a group with 3 distinct values). $ ${AUGUR} filter \ diff --git a/tests/test_filter.py b/tests/test_filter.py index cd41e69e6..ca884c00f 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -165,19 +165,9 @@ def test_filter_on_query_good(self, tmpdir, sequences): ("SEQ_2","colorado","bad"), ("SEQ_3","nevada","good"))) metadata, columns = read_metadata(meta_fn, as_data_frame=True) - filtered = augur.filter.filter_by_query(set(sequences.keys()), metadata, 'quality=="good"') + filtered = augur.filter.filter_by_query(metadata, 'quality=="good"') assert sorted(filtered) == ["SEQ_1", "SEQ_3"] - def test_filter_on_query_subset(self, tmpdir): - """Test filtering on query works when given fewer strains than metadata""" - meta_fn = write_metadata(tmpdir, (("strain","location","quality"), - ("SEQ_1","colorado","good"), - ("SEQ_2","colorado","bad"), - ("SEQ_3","nevada","good"))) - metadata, columns = read_metadata(meta_fn, as_data_frame=True) - filtered = augur.filter.filter_by_query({"SEQ_2"}, metadata, 'quality=="bad" & location=="colorado"') - assert sorted(filtered) == ["SEQ_2"] - def test_filter_run_with_query(self, tmpdir, fasta_fn, argparser): """Test that filter --query works as expected""" out_fn = str(tmpdir / "out.fasta")