From 43e6ded49b745e158d74b14466711ae38673ad4d Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Sat, 26 Jun 2021 21:11:21 -0700 Subject: [PATCH 1/2] Use DataFrame for metadata instead of dict Allows metadata to be returned as a pandas DataFrame instead of a dictionary and uses this API to load metadata as a DataFrame in augur filter. Updates the corresponding metadata references in augur filter and related function calls to work with this new data structure. This change allows us to eventually switch to chunking metadata input to avoid reading it all into memory. --- augur/filter.py | 78 ++++++++++++++--------------- augur/util_support/metadata_file.py | 11 ++-- augur/utils.py | 74 ++++++++++++++++++++------- tests/test_filter.py | 13 +++-- 4 files changed, 110 insertions(+), 66 deletions(-) diff --git a/augur/filter.py b/augur/filter.py index 10fdaa68f..d263540bc 100644 --- a/augur/filter.py +++ b/augur/filter.py @@ -72,25 +72,25 @@ def read_priority_scores(fname): print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr) raise e -def filter_by_query(sequences, metadata_file, query): - """Filter a set of sequences using Pandas DataFrame querying against the metadata file. +def filter_by_query(strains, metadata, query): + """Filter a set of strains using Pandas DataFrame querying against the metadata file. Parameters ---------- - sequences : list[str] - List of sequence names to filter - metadata_file : str - Path to the metadata associated wtih the sequences + strains : set[str] + Set of strain names to filter + metadata : pandas.DataFrame + Metadata associated with the given strains query : str Query string for the dataframe. Returns ------- - list[str]: - List of sequence names that match the given query + set[str]: + Set of strains that match the given query """ - filtered_meta_dict, _ = read_metadata(metadata_file, query) - return [seq for seq in sequences if seq in filtered_meta_dict] + filtered_strains = set(metadata.query(query).index.values) + return strains & filtered_strains def register_arguments(parser): input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered") @@ -177,8 +177,8 @@ def run(args): try: # Metadata are the source of truth for which sequences we want to keep # in filtered output. - meta_dict, meta_columns = read_metadata(args.metadata) - metadata_strains = set(meta_dict.keys()) + metadata, meta_columns = read_metadata(args.metadata, as_data_frame=True) + metadata_strains = set(metadata.index.values) except ValueError as error: print("ERROR: Problem reading in {}:".format(args.metadata)) print(error) @@ -297,10 +297,10 @@ def run(args): to_exclude = set() for seq_name in seq_keep: if "!=" in ex: # i.e. property!=value requested - if meta_dict[seq_name].get(col,'unknown').lower() != val.lower(): + if metadata.loc[seq_name].get(col,'unknown').lower() != val.lower(): to_exclude.add(seq_name) else: # i.e. property=value requested - if meta_dict[seq_name].get(col,'unknown').lower() == val.lower(): + if metadata.loc[seq_name].get(col,'unknown').lower() == val.lower(): to_exclude.add(seq_name) num_excluded_by_metadata[ex] = len(seq_keep & to_exclude) @@ -309,34 +309,16 @@ def run(args): # exclude strains by metadata, using Pandas querying num_excluded_by_query = 0 if args.query: - filtered = set(filter_by_query(list(seq_keep), args.metadata, args.query)) + filtered = filter_by_query(seq_keep, metadata, args.query) num_excluded_by_query = len(seq_keep - filtered) seq_keep = filtered - # filter by sequence length - num_excluded_by_length = 0 - if args.min_length: - if is_vcf: #doesn't make sense for VCF, ignore. - print("WARNING: Cannot use min_length for VCF files. Ignoring...") - else: - is_in_seq_keep = sequence_index["strain"].isin(seq_keep) - is_gte_min_length = sequence_index["ACGT"] >= args.min_length - - seq_keep_by_length = set( - sequence_index[ - (is_in_seq_keep) & (is_gte_min_length) - ]["strain"].tolist() - ) - - num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length) - seq_keep = seq_keep_by_length - # filter by ambiguous dates num_excluded_by_ambiguous_date = 0 if args.exclude_ambiguous_dates_by and 'date' in meta_columns: seq_keep_by_date = set() for seq_name in seq_keep: - if not is_date_ambiguous(meta_dict[seq_name]['date'], args.exclude_ambiguous_dates_by): + if not is_date_ambiguous(metadata.loc[seq_name]['date'], args.exclude_ambiguous_dates_by): seq_keep_by_date.add(seq_name) num_excluded_by_ambiguous_date = len(seq_keep) - len(seq_keep_by_date) @@ -345,7 +327,7 @@ def run(args): # filter by date num_excluded_by_date = 0 if (args.min_date or args.max_date) and 'date' in meta_columns: - dates = get_numerical_dates(meta_dict, fmt="%Y-%m-%d") + dates = get_numerical_dates(metadata, fmt="%Y-%m-%d") tmp = {s for s in seq_keep if dates[s] is not None} if args.min_date: tmp = {s for s in tmp if (np.isscalar(dates[s]) or all(dates[s])) and np.max(dates[s])>=args.min_date} @@ -354,6 +336,24 @@ def run(args): num_excluded_by_date = len(seq_keep) - len(tmp) seq_keep = tmp + # filter by sequence length + num_excluded_by_length = 0 + if args.min_length: + if is_vcf: #doesn't make sense for VCF, ignore. + print("WARNING: Cannot use min_length for VCF files. Ignoring...") + else: + is_in_seq_keep = sequence_index["strain"].isin(seq_keep) + is_gte_min_length = sequence_index["ACGT"] >= args.min_length + + seq_keep_by_length = set( + sequence_index[ + (is_in_seq_keep) & (is_gte_min_length) + ]["strain"].tolist() + ) + + num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length) + seq_keep = seq_keep_by_length + # exclude sequences with non-nucleotide characters num_excluded_by_nuc = 0 if args.non_nucleotide: @@ -385,7 +385,7 @@ def run(args): probabilistic_sampling = False if args.subsample_max_sequences or (args.group_by and args.sequences_per_group): - + #set groups to group_by values if args.group_by: groups = args.group_by @@ -398,7 +398,7 @@ def run(args): for seq_name in seq_keep: group = [] - m = meta_dict[seq_name] + m = metadata.loc[seq_name].to_dict() # collect group specifiers for c in groups: if c == "_dummy": @@ -551,7 +551,7 @@ def run(args): # loop over all sequences and re-add sequences for seq_name in available_strains: - if meta_dict[seq_name].get(col)==val: + if metadata.loc[seq_name].get(col)==val: to_include.add(seq_name) num_included_by_metadata = len(to_include) @@ -608,7 +608,7 @@ def run(args): num_excluded_by_lack_of_sequences = len(metadata_strains - sequence_strains) if args.output_metadata: - metadata_df = pd.DataFrame([meta_dict[strain] for strain in seq_keep]) + metadata_df = metadata.loc[seq_keep] metadata_df.to_csv( args.output_metadata, sep="\t", diff --git a/augur/util_support/metadata_file.py b/augur/util_support/metadata_file.py index 7c7508148..43a963317 100644 --- a/augur/util_support/metadata_file.py +++ b/augur/util_support/metadata_file.py @@ -11,9 +11,10 @@ class MetadataFile: which is used to match metadata with samples. """ - def __init__(self, fname, query=None): + def __init__(self, fname, query=None, as_data_frame=False): self.fname = fname self.query = query + self.as_data_frame = as_data_frame self.key_type = self.find_key_type() @@ -26,8 +27,12 @@ def read(self): # original "strain"/"name" remains in the output. self.metadata["_index"] = self.metadata[self.key_type] - metadata_dict = self.metadata.set_index("_index").to_dict("index") - return metadata_dict, self.columns + metadata = self.metadata.set_index("_index") + + if self.as_data_frame: + return metadata, self.columns + else: + return metadata.to_dict("index"), self.columns @property @functools.lru_cache() diff --git a/augur/utils.py b/augur/utils.py index b025ab9d0..87a0bbea6 100644 --- a/augur/utils.py +++ b/augur/utils.py @@ -1,6 +1,7 @@ import argparse import Bio import Bio.Phylo +from datetime import datetime import gzip import os, json, sys import pandas as pd @@ -59,8 +60,8 @@ def get_json_name(args, default=None): def ambiguous_date_to_date_range(uncertain_date, fmt, min_max_year=None): return DateDisambiguator(uncertain_date, fmt=fmt, min_max_year=min_max_year).range() -def read_metadata(fname, query=None): - return MetadataFile(fname, query).read() +def read_metadata(fname, query=None, as_data_frame=False): + return MetadataFile(fname, query, as_data_frame).read() def is_date_ambiguous(date, ambiguous_by="any"): """ @@ -93,28 +94,63 @@ def is_date_ambiguous(date, ambiguous_by="any"): "X" in day and ambiguous_by in ("any", "day") )) +def get_numerical_date_from_value(value, fmt=None, min_max_year=None, raise_error=True): + if type(value)!=str: + if raise_error: + raise ValueError(value) + else: + numerical_date = None + elif 'XX' in value: + ambig_date = ambiguous_date_to_date_range(value, fmt, min_max_year) + if ambig_date is None or None in ambig_date: + numerical_date = [None, None] #don't send to numeric_date or will be set to today + else: + numerical_date = [numeric_date(d) for d in ambig_date] + else: + try: + numerical_date = numeric_date(datetime.strptime(value, fmt)) + except: + numerical_date = None + + return numerical_date + def get_numerical_dates(meta_dict, name_col = None, date_col='date', fmt=None, min_max_year=None): if fmt: - from datetime import datetime numerical_dates = {} - for k,m in meta_dict.items(): - v = m[date_col] - if type(v)!=str: - print("WARNING: %s has an invalid data string:"%k,v) - continue - elif 'XX' in v: - ambig_date = ambiguous_date_to_date_range(v, fmt, min_max_year) - if ambig_date is None or None in ambig_date: - numerical_dates[k] = [None, None] #don't send to numeric_date or will be set to today - else: - numerical_dates[k] = [numeric_date(d) for d in ambig_date] - else: + + if isinstance(meta_dict, dict): + for k,m in meta_dict.items(): + v = m[date_col] try: - numerical_dates[k] = numeric_date(datetime.strptime(v, fmt)) - except: - numerical_dates[k] = None + numerical_dates[k] = get_numerical_date_from_value( + v, + fmt, + min_max_year + ) + except ValueError: + print( + "WARNING: %s has an invalid data string: %s"% (k, v), + file=sys.stderr + ) + continue + elif isinstance(meta_dict, pd.DataFrame): + strains = meta_dict.index.values + dates = meta_dict[date_col].apply( + lambda date: get_numerical_date_from_value( + date, + fmt, + min_max_year, + raise_error=False + ) + ).values + numerical_dates = dict(zip(strains, dates)) else: - numerical_dates = {k:float(v) for k,v in meta_dict.items()} + if isinstance(meta_dict, dict): + numerical_dates = {k:float(v) for k,v in meta_dict.items()} + elif isinstance(meta_dict, pd.DataFrame): + strains = meta_dict.index.values + dates = meta_dict[date_col].astype(float) + numerical_dates = dict(zip(strains, dates)) return numerical_dates diff --git a/tests/test_filter.py b/tests/test_filter.py index 106617e46..cd41e69e6 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -9,6 +9,7 @@ from Bio.SeqRecord import SeqRecord import augur.filter +from augur.utils import read_metadata @pytest.fixture def argparser(): @@ -163,8 +164,9 @@ def test_filter_on_query_good(self, tmpdir, sequences): ("SEQ_1","colorado","good"), ("SEQ_2","colorado","bad"), ("SEQ_3","nevada","good"))) - filtered = augur.filter.filter_by_query(sequences.keys(), meta_fn, 'quality=="good"') - assert filtered == ["SEQ_1", "SEQ_3"] + metadata, columns = read_metadata(meta_fn, as_data_frame=True) + filtered = augur.filter.filter_by_query(set(sequences.keys()), metadata, 'quality=="good"') + assert sorted(filtered) == ["SEQ_1", "SEQ_3"] def test_filter_on_query_subset(self, tmpdir): """Test filtering on query works when given fewer strains than metadata""" @@ -172,8 +174,9 @@ def test_filter_on_query_subset(self, tmpdir): ("SEQ_1","colorado","good"), ("SEQ_2","colorado","bad"), ("SEQ_3","nevada","good"))) - filtered = augur.filter.filter_by_query(["SEQ_2"], meta_fn, 'quality=="bad" & location=="colorado"') - assert filtered == ["SEQ_2"] + metadata, columns = read_metadata(meta_fn, as_data_frame=True) + filtered = augur.filter.filter_by_query({"SEQ_2"}, metadata, 'quality=="bad" & location=="colorado"') + assert sorted(filtered) == ["SEQ_2"] def test_filter_run_with_query(self, tmpdir, fasta_fn, argparser): """Test that filter --query works as expected""" @@ -242,4 +245,4 @@ def test_filter_run_max_date(self, tmpdir, fasta_fn, argparser): % (fasta_fn, meta_fn, out_fn, max_date)) augur.filter.run(args) output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta")) - assert list(output.keys()) == ["SEQ_1", "SEQ_2"] \ No newline at end of file + assert list(output.keys()) == ["SEQ_1", "SEQ_2"] From bb81de2d0e98db1825c59917cc3519d142b3561d Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Tue, 6 Jul 2021 12:14:36 -0700 Subject: [PATCH 2/2] Ignore spurious pylint warnings pylint complains about several stylistic issues with our code that are always low priority (e.g., import order). This commit ignores some of the most annoying warnings, so built-in linters (in emacs, etc.) are more useful for catching real issues. --- .pylintrc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index b5f93e79d..8e8e70392 100644 --- a/.pylintrc +++ b/.pylintrc @@ -137,7 +137,12 @@ disable=print-statement, missing-docstring, bad-whitespace, line-too-long, - invalid-name + invalid-name, + wrong-import-order, + multiple-imports, + no-else-return, + unscriptable-object, + relative-beyond-top-level # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option