diff --git a/.pylintrc b/.pylintrc index b5f93e79d..8e8e70392 100644 --- a/.pylintrc +++ b/.pylintrc @@ -137,7 +137,12 @@ disable=print-statement, missing-docstring, bad-whitespace, line-too-long, - invalid-name + invalid-name, + wrong-import-order, + multiple-imports, + no-else-return, + unscriptable-object, + relative-beyond-top-level # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/augur/filter.py b/augur/filter.py index 10fdaa68f..d263540bc 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -72,25 +72,25 @@ def read_priority_scores(fname): print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr) raise e -def filter_by_query(sequences, metadata_file, query): - """Filter a set of sequences using Pandas DataFrame querying against the metadata file. +def filter_by_query(strains, metadata, query): + """Filter a set of strains using Pandas DataFrame querying against the metadata file. Parameters ---------- - sequences : list[str] - List of sequence names to filter - metadata_file : str - Path to the metadata associated wtih the sequences + strains : set[str] + Set of strain names to filter + metadata : pandas.DataFrame + Metadata associated with the given strains query : str Query string for the dataframe. Returns ------- - list[str]: - List of sequence names that match the given query + set[str]: + Set of strains that match the given query """ - filtered_meta_dict, _ = read_metadata(metadata_file, query) - return [seq for seq in sequences if seq in filtered_meta_dict] + filtered_strains = set(metadata.query(query).index.values) + return strains & filtered_strains def register_arguments(parser): input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered") @@ -177,8 +177,8 @@ def run(args): try: # Metadata are the source of truth for which sequences we want to keep # in filtered output. - meta_dict, meta_columns = read_metadata(args.metadata) - metadata_strains = set(meta_dict.keys()) + metadata, meta_columns = read_metadata(args.metadata, as_data_frame=True) + metadata_strains = set(metadata.index.values) except ValueError as error: print("ERROR: Problem reading in {}:".format(args.metadata)) print(error) @@ -297,10 +297,10 @@ def run(args): to_exclude = set() for seq_name in seq_keep: if "!=" in ex: # i.e. property!=value requested - if meta_dict[seq_name].get(col,'unknown').lower() != val.lower(): + if metadata.loc[seq_name].get(col,'unknown').lower() != val.lower(): to_exclude.add(seq_name) else: # i.e. property=value requested - if meta_dict[seq_name].get(col,'unknown').lower() == val.lower(): + if metadata.loc[seq_name].get(col,'unknown').lower() == val.lower(): to_exclude.add(seq_name) num_excluded_by_metadata[ex] = len(seq_keep & to_exclude) @@ -309,34 +309,16 @@ def run(args): # exclude strains by metadata, using Pandas querying num_excluded_by_query = 0 if args.query: - filtered = set(filter_by_query(list(seq_keep), args.metadata, args.query)) + filtered = filter_by_query(seq_keep, metadata, args.query) num_excluded_by_query = len(seq_keep - filtered) seq_keep = filtered - # filter by sequence length - num_excluded_by_length = 0 - if args.min_length: - if is_vcf: #doesn't make sense for VCF, ignore. - print("WARNING: Cannot use min_length for VCF files. Ignoring...") - else: - is_in_seq_keep = sequence_index["strain"].isin(seq_keep) - is_gte_min_length = sequence_index["ACGT"] >= args.min_length - - seq_keep_by_length = set( - sequence_index[ - (is_in_seq_keep) & (is_gte_min_length) - ]["strain"].tolist() - ) - - num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length) - seq_keep = seq_keep_by_length - # filter by ambiguous dates num_excluded_by_ambiguous_date = 0 if args.exclude_ambiguous_dates_by and 'date' in meta_columns: seq_keep_by_date = set() for seq_name in seq_keep: - if not is_date_ambiguous(meta_dict[seq_name]['date'], args.exclude_ambiguous_dates_by): + if not is_date_ambiguous(metadata.loc[seq_name]['date'], args.exclude_ambiguous_dates_by): seq_keep_by_date.add(seq_name) num_excluded_by_ambiguous_date = len(seq_keep) - len(seq_keep_by_date) @@ -345,7 +327,7 @@ def run(args): # filter by date num_excluded_by_date = 0 if (args.min_date or args.max_date) and 'date' in meta_columns: - dates = get_numerical_dates(meta_dict, fmt="%Y-%m-%d") + dates = get_numerical_dates(metadata, fmt="%Y-%m-%d") tmp = {s for s in seq_keep if dates[s] is not None} if args.min_date: tmp = {s for s in tmp if (np.isscalar(dates[s]) or all(dates[s])) and np.max(dates[s])>=args.min_date} @@ -354,6 +336,24 @@ def run(args): num_excluded_by_date = len(seq_keep) - len(tmp) seq_keep = tmp + # filter by sequence length + num_excluded_by_length = 0 + if args.min_length: + if is_vcf: #doesn't make sense for VCF, ignore. + print("WARNING: Cannot use min_length for VCF files. Ignoring...") + else: + is_in_seq_keep = sequence_index["strain"].isin(seq_keep) + is_gte_min_length = sequence_index["ACGT"] >= args.min_length + + seq_keep_by_length = set( + sequence_index[ + (is_in_seq_keep) & (is_gte_min_length) + ]["strain"].tolist() + ) + + num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length) + seq_keep = seq_keep_by_length + # exclude sequences with non-nucleotide characters num_excluded_by_nuc = 0 if args.non_nucleotide: @@ -385,7 +385,7 @@ def run(args): probabilistic_sampling = False if args.subsample_max_sequences or (args.group_by and args.sequences_per_group): - + #set groups to group_by values if args.group_by: groups = args.group_by @@ -398,7 +398,7 @@ def run(args): for seq_name in seq_keep: group = [] - m = meta_dict[seq_name] + m = metadata.loc[seq_name].to_dict() # collect group specifiers for c in groups: if c == "_dummy": @@ -551,7 +551,7 @@ def run(args): # loop over all sequences and re-add sequences for seq_name in available_strains: - if meta_dict[seq_name].get(col)==val: + if metadata.loc[seq_name].get(col)==val: to_include.add(seq_name) num_included_by_metadata = len(to_include) @@ -608,7 +608,7 @@ def run(args): num_excluded_by_lack_of_sequences = len(metadata_strains - sequence_strains) if args.output_metadata: - metadata_df = pd.DataFrame([meta_dict[strain] for strain in seq_keep]) + metadata_df = metadata.loc[seq_keep] metadata_df.to_csv( args.output_metadata, sep="\t", diff --git a/augur/util_support/metadata_file.py b/augur/util_support/metadata_file.py index 7c7508148..43a963317 100644 --- a/augur/util_support/metadata_file.py +++ b/augur/util_support/metadata_file.py @@ -11,9 +11,10 @@ class MetadataFile: which is used to match metadata with samples. """ - def __init__(self, fname, query=None): + def __init__(self, fname, query=None, as_data_frame=False): self.fname = fname self.query = query + self.as_data_frame = as_data_frame self.key_type = self.find_key_type() @@ -26,8 +27,12 @@ def read(self): # original "strain"/"name" remains in the output. self.metadata["_index"] = self.metadata[self.key_type] - metadata_dict = self.metadata.set_index("_index").to_dict("index") - return metadata_dict, self.columns + metadata = self.metadata.set_index("_index") + + if self.as_data_frame: + return metadata, self.columns + else: + return metadata.to_dict("index"), self.columns @property @functools.lru_cache() diff --git a/augur/utils.py b/augur/utils.py index b025ab9d0..87a0bbea6 100644 --- a/augur/utils.py +++ b/augur/utils.py @@ -1,6 +1,7 @@ import argparse import Bio import Bio.Phylo +from datetime import datetime import gzip import os, json, sys import pandas as pd @@ -59,8 +60,8 @@ def get_json_name(args, default=None): def ambiguous_date_to_date_range(uncertain_date, fmt, min_max_year=None): return DateDisambiguator(uncertain_date, fmt=fmt, min_max_year=min_max_year).range() -def read_metadata(fname, query=None): - return MetadataFile(fname, query).read() +def read_metadata(fname, query=None, as_data_frame=False): + return MetadataFile(fname, query, as_data_frame).read() def is_date_ambiguous(date, ambiguous_by="any"): """ @@ -93,28 +94,63 @@ def is_date_ambiguous(date, ambiguous_by="any"): "X" in day and ambiguous_by in ("any", "day") )) +def get_numerical_date_from_value(value, fmt=None, min_max_year=None, raise_error=True): + if type(value)!=str: + if raise_error: + raise ValueError(value) + else: + numerical_date = None + elif 'XX' in value: + ambig_date = ambiguous_date_to_date_range(value, fmt, min_max_year) + if ambig_date is None or None in ambig_date: + numerical_date = [None, None] #don't send to numeric_date or will be set to today + else: + numerical_date = [numeric_date(d) for d in ambig_date] + else: + try: + numerical_date = numeric_date(datetime.strptime(value, fmt)) + except: + numerical_date = None + + return numerical_date + def get_numerical_dates(meta_dict, name_col = None, date_col='date', fmt=None, min_max_year=None): if fmt: - from datetime import datetime numerical_dates = {} - for k,m in meta_dict.items(): - v = m[date_col] - if type(v)!=str: - print("WARNING: %s has an invalid data string:"%k,v) - continue - elif 'XX' in v: - ambig_date = ambiguous_date_to_date_range(v, fmt, min_max_year) - if ambig_date is None or None in ambig_date: - numerical_dates[k] = [None, None] #don't send to numeric_date or will be set to today - else: - numerical_dates[k] = [numeric_date(d) for d in ambig_date] - else: + + if isinstance(meta_dict, dict): + for k,m in meta_dict.items(): + v = m[date_col] try: - numerical_dates[k] = numeric_date(datetime.strptime(v, fmt)) - except: - numerical_dates[k] = None + numerical_dates[k] = get_numerical_date_from_value( + v, + fmt, + min_max_year + ) + except ValueError: + print( + "WARNING: %s has an invalid data string: %s"% (k, v), + file=sys.stderr + ) + continue + elif isinstance(meta_dict, pd.DataFrame): + strains = meta_dict.index.values + dates = meta_dict[date_col].apply( + lambda date: get_numerical_date_from_value( + date, + fmt, + min_max_year, + raise_error=False + ) + ).values + numerical_dates = dict(zip(strains, dates)) else: - numerical_dates = {k:float(v) for k,v in meta_dict.items()} + if isinstance(meta_dict, dict): + numerical_dates = {k:float(v) for k,v in meta_dict.items()} + elif isinstance(meta_dict, pd.DataFrame): + strains = meta_dict.index.values + dates = meta_dict[date_col].astype(float) + numerical_dates = dict(zip(strains, dates)) return numerical_dates diff --git a/tests/test_filter.py b/tests/test_filter.py index 106617e46..cd41e69e6 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -9,6 +9,7 @@ from Bio.SeqRecord import SeqRecord import augur.filter +from augur.utils import read_metadata @pytest.fixture def argparser(): @@ -163,8 +164,9 @@ def test_filter_on_query_good(self, tmpdir, sequences): ("SEQ_1","colorado","good"), ("SEQ_2","colorado","bad"), ("SEQ_3","nevada","good"))) - filtered = augur.filter.filter_by_query(sequences.keys(), meta_fn, 'quality=="good"') - assert filtered == ["SEQ_1", "SEQ_3"] + metadata, columns = read_metadata(meta_fn, as_data_frame=True) + filtered = augur.filter.filter_by_query(set(sequences.keys()), metadata, 'quality=="good"') + assert sorted(filtered) == ["SEQ_1", "SEQ_3"] def test_filter_on_query_subset(self, tmpdir): """Test filtering on query works when given fewer strains than metadata""" @@ -172,8 +174,9 @@ def test_filter_on_query_subset(self, tmpdir): ("SEQ_1","colorado","good"), ("SEQ_2","colorado","bad"), ("SEQ_3","nevada","good"))) - filtered = augur.filter.filter_by_query(["SEQ_2"], meta_fn, 'quality=="bad" & location=="colorado"') - assert filtered == ["SEQ_2"] + metadata, columns = read_metadata(meta_fn, as_data_frame=True) + filtered = augur.filter.filter_by_query({"SEQ_2"}, metadata, 'quality=="bad" & location=="colorado"') + assert sorted(filtered) == ["SEQ_2"] def test_filter_run_with_query(self, tmpdir, fasta_fn, argparser): """Test that filter --query works as expected""" @@ -242,4 +245,4 @@ def test_filter_run_max_date(self, tmpdir, fasta_fn, argparser): % (fasta_fn, meta_fn, out_fn, max_date)) augur.filter.run(args) output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) - assert list(output.keys()) == ["SEQ_1", "SEQ_2"] \ No newline at end of file + assert list(output.keys()) == ["SEQ_1", "SEQ_2"]