Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable reading metadata as a pandas DataFrame #743

Merged
merged 2 commits into from
Jul 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ disable=print-statement,
missing-docstring,
bad-whitespace,
line-too-long,
invalid-name
invalid-name,
wrong-import-order,
multiple-imports,
no-else-return,
unscriptable-object,
relative-beyond-top-level

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
78 changes: 39 additions & 39 deletions augur/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,25 @@ def read_priority_scores(fname):
print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr)
raise e

def filter_by_query(sequences, metadata_file, query):
"""Filter a set of sequences using Pandas DataFrame querying against the metadata file.
def filter_by_query(strains, metadata, query):
"""Filter a set of strains using Pandas DataFrame querying against the metadata file.

Parameters
----------
sequences : list[str]
List of sequence names to filter
metadata_file : str
Path to the metadata associated with the sequences
strains : set[str]
Set of strain names to filter
metadata : pandas.DataFrame
Metadata associated with the given strains
query : str
Query string for the dataframe.

Returns
-------
list[str]:
List of sequence names that match the given query
set[str]:
Set of strains that match the given query
"""
filtered_meta_dict, _ = read_metadata(metadata_file, query)
return [seq for seq in sequences if seq in filtered_meta_dict]
filtered_strains = set(metadata.query(query).index.values)
return strains & filtered_strains

def register_arguments(parser):
input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered")
Expand Down Expand Up @@ -177,8 +177,8 @@ def run(args):
try:
# Metadata are the source of truth for which sequences we want to keep
# in filtered output.
meta_dict, meta_columns = read_metadata(args.metadata)
metadata_strains = set(meta_dict.keys())
metadata, meta_columns = read_metadata(args.metadata, as_data_frame=True)
metadata_strains = set(metadata.index.values)
except ValueError as error:
print("ERROR: Problem reading in {}:".format(args.metadata))
print(error)
Expand Down Expand Up @@ -297,10 +297,10 @@ def run(args):
to_exclude = set()
for seq_name in seq_keep:
if "!=" in ex: # i.e. property!=value requested
if meta_dict[seq_name].get(col,'unknown').lower() != val.lower():
if metadata.loc[seq_name].get(col,'unknown').lower() != val.lower():
to_exclude.add(seq_name)
else: # i.e. property=value requested
if meta_dict[seq_name].get(col,'unknown').lower() == val.lower():
if metadata.loc[seq_name].get(col,'unknown').lower() == val.lower():
to_exclude.add(seq_name)

num_excluded_by_metadata[ex] = len(seq_keep & to_exclude)
Expand All @@ -309,34 +309,16 @@ def run(args):
# exclude strains by metadata, using Pandas querying
num_excluded_by_query = 0
if args.query:
filtered = set(filter_by_query(list(seq_keep), args.metadata, args.query))
filtered = filter_by_query(seq_keep, metadata, args.query)
num_excluded_by_query = len(seq_keep - filtered)
seq_keep = filtered

# filter by sequence length
num_excluded_by_length = 0
if args.min_length:
if is_vcf: #doesn't make sense for VCF, ignore.
print("WARNING: Cannot use min_length for VCF files. Ignoring...")
else:
is_in_seq_keep = sequence_index["strain"].isin(seq_keep)
is_gte_min_length = sequence_index["ACGT"] >= args.min_length

seq_keep_by_length = set(
sequence_index[
(is_in_seq_keep) & (is_gte_min_length)
]["strain"].tolist()
)

num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length)
seq_keep = seq_keep_by_length

# filter by ambiguous dates
num_excluded_by_ambiguous_date = 0
if args.exclude_ambiguous_dates_by and 'date' in meta_columns:
seq_keep_by_date = set()
for seq_name in seq_keep:
if not is_date_ambiguous(meta_dict[seq_name]['date'], args.exclude_ambiguous_dates_by):
if not is_date_ambiguous(metadata.loc[seq_name]['date'], args.exclude_ambiguous_dates_by):
seq_keep_by_date.add(seq_name)

num_excluded_by_ambiguous_date = len(seq_keep) - len(seq_keep_by_date)
Expand All @@ -345,7 +327,7 @@ def run(args):
# filter by date
num_excluded_by_date = 0
if (args.min_date or args.max_date) and 'date' in meta_columns:
dates = get_numerical_dates(meta_dict, fmt="%Y-%m-%d")
dates = get_numerical_dates(metadata, fmt="%Y-%m-%d")
tmp = {s for s in seq_keep if dates[s] is not None}
if args.min_date:
tmp = {s for s in tmp if (np.isscalar(dates[s]) or all(dates[s])) and np.max(dates[s])>=args.min_date}
Expand All @@ -354,6 +336,24 @@ def run(args):
num_excluded_by_date = len(seq_keep) - len(tmp)
seq_keep = tmp

# filter by sequence length
num_excluded_by_length = 0
if args.min_length:
if is_vcf: #doesn't make sense for VCF, ignore.
print("WARNING: Cannot use min_length for VCF files. Ignoring...")
else:
is_in_seq_keep = sequence_index["strain"].isin(seq_keep)
is_gte_min_length = sequence_index["ACGT"] >= args.min_length

seq_keep_by_length = set(
sequence_index[
(is_in_seq_keep) & (is_gte_min_length)
]["strain"].tolist()
)

num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length)
seq_keep = seq_keep_by_length

# exclude sequences with non-nucleotide characters
num_excluded_by_nuc = 0
if args.non_nucleotide:
Expand Down Expand Up @@ -385,7 +385,7 @@ def run(args):
probabilistic_sampling = False

if args.subsample_max_sequences or (args.group_by and args.sequences_per_group):

#set groups to group_by values
if args.group_by:
groups = args.group_by
Expand All @@ -398,7 +398,7 @@ def run(args):

for seq_name in seq_keep:
group = []
m = meta_dict[seq_name]
m = metadata.loc[seq_name].to_dict()
# collect group specifiers
for c in groups:
if c == "_dummy":
Expand Down Expand Up @@ -551,7 +551,7 @@ def run(args):

# loop over all sequences and re-add sequences
for seq_name in available_strains:
if meta_dict[seq_name].get(col)==val:
if metadata.loc[seq_name].get(col)==val:
to_include.add(seq_name)

num_included_by_metadata = len(to_include)
Expand Down Expand Up @@ -608,7 +608,7 @@ def run(args):
num_excluded_by_lack_of_sequences = len(metadata_strains - sequence_strains)

if args.output_metadata:
metadata_df = pd.DataFrame([meta_dict[strain] for strain in seq_keep])
metadata_df = metadata.loc[seq_keep]
metadata_df.to_csv(
args.output_metadata,
sep="\t",
Expand Down
11 changes: 8 additions & 3 deletions augur/util_support/metadata_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ class MetadataFile:
which is used to match metadata with samples.
"""

def __init__(self, fname, query=None):
def __init__(self, fname, query=None, as_data_frame=False):
self.fname = fname
self.query = query
self.as_data_frame = as_data_frame

self.key_type = self.find_key_type()

Expand All @@ -26,8 +27,12 @@ def read(self):
# original "strain"/"name" remains in the output.
self.metadata["_index"] = self.metadata[self.key_type]

metadata_dict = self.metadata.set_index("_index").to_dict("index")
return metadata_dict, self.columns
metadata = self.metadata.set_index("_index")

if self.as_data_frame:
return metadata, self.columns
else:
return metadata.to_dict("index"), self.columns

@property
@functools.lru_cache()
Expand Down
74 changes: 55 additions & 19 deletions augur/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import Bio
import Bio.Phylo
from datetime import datetime
import gzip
import os, json, sys
import pandas as pd
Expand Down Expand Up @@ -59,8 +60,8 @@ def get_json_name(args, default=None):
def ambiguous_date_to_date_range(uncertain_date, fmt, min_max_year=None):
return DateDisambiguator(uncertain_date, fmt=fmt, min_max_year=min_max_year).range()

def read_metadata(fname, query=None):
return MetadataFile(fname, query).read()
def read_metadata(fname, query=None, as_data_frame=False):
    """Read the metadata file ``fname`` through ``MetadataFile``.

    All three arguments are handed to the ``MetadataFile`` constructor
    positionally and the result of its ``read()`` method is returned
    unchanged.

    Parameters
    ----------
    fname : str
        Path of the metadata file to read.
    query : str, optional
        Query forwarded to ``MetadataFile``.
    as_data_frame : bool, optional
        Forwarded to ``MetadataFile``; selects whether ``read()`` hands back
        the metadata as a DataFrame or as a dictionary.

    Returns
    -------
    Whatever ``MetadataFile.read()`` returns (the metadata together with its
    columns).
    """
    metadata_file = MetadataFile(fname, query, as_data_frame)
    return metadata_file.read()

def is_date_ambiguous(date, ambiguous_by="any"):
"""
Expand Down Expand Up @@ -93,28 +94,63 @@ def is_date_ambiguous(date, ambiguous_by="any"):
"X" in day and ambiguous_by in ("any", "day")
))

def get_numerical_date_from_value(value, fmt=None, min_max_year=None, raise_error=True):
    """Convert a single date string into its numerical representation.

    Parameters
    ----------
    value : str
        Date string to convert. A value containing "XX" is treated as an
        ambiguous date and expanded into a date range.
    fmt : str
        Format passed to ``datetime.strptime`` for unambiguous dates and to
        ``ambiguous_date_to_date_range`` for ambiguous ones.
    min_max_year : optional
        Forwarded unchanged to ``ambiguous_date_to_date_range``.
    raise_error : bool
        When True (the default), a non-string ``value`` raises ``ValueError``;
        when False it yields ``None`` instead.

    Returns
    -------
    None
        If ``value`` is not a string (and ``raise_error`` is False) or the
        date could not be parsed with ``fmt``.
    list
        For ambiguous dates: ``[None, None]`` when the range could not be
        resolved, otherwise ``numeric_date`` applied to each entry of the
        range.
    otherwise
        The result of ``numeric_date`` for the parsed date.

    Raises
    ------
    ValueError
        If ``value`` is not a string and ``raise_error`` is True.
    """
    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses such as numpy.str_.
    if not isinstance(value, str):
        if raise_error:
            raise ValueError(value)
        return None

    if 'XX' in value:
        ambig_date = ambiguous_date_to_date_range(value, fmt, min_max_year)
        if ambig_date is None or None in ambig_date:
            # don't send to numeric_date or will be set to today
            return [None, None]
        return [numeric_date(d) for d in ambig_date]

    try:
        return numeric_date(datetime.strptime(value, fmt))
    except Exception:
        # Any parsing/conversion failure means "no usable date". Unlike a bare
        # except, this lets KeyboardInterrupt and SystemExit propagate.
        return None

def get_numerical_dates(meta_dict, name_col = None, date_col='date', fmt=None, min_max_year=None):
if fmt:
from datetime import datetime
numerical_dates = {}
for k,m in meta_dict.items():
v = m[date_col]
if type(v)!=str:
print("WARNING: %s has an invalid data string:"%k,v)
continue
elif 'XX' in v:
ambig_date = ambiguous_date_to_date_range(v, fmt, min_max_year)
if ambig_date is None or None in ambig_date:
numerical_dates[k] = [None, None] #don't send to numeric_date or will be set to today
else:
numerical_dates[k] = [numeric_date(d) for d in ambig_date]
else:

if isinstance(meta_dict, dict):
for k,m in meta_dict.items():
v = m[date_col]
try:
numerical_dates[k] = numeric_date(datetime.strptime(v, fmt))
except:
numerical_dates[k] = None
numerical_dates[k] = get_numerical_date_from_value(
v,
fmt,
min_max_year
)
except ValueError:
print(
"WARNING: %s has an invalid data string: %s"% (k, v),
file=sys.stderr
)
continue
elif isinstance(meta_dict, pd.DataFrame):
strains = meta_dict.index.values
dates = meta_dict[date_col].apply(
lambda date: get_numerical_date_from_value(
date,
fmt,
min_max_year,
raise_error=False
)
).values
numerical_dates = dict(zip(strains, dates))
else:
numerical_dates = {k:float(v) for k,v in meta_dict.items()}
if isinstance(meta_dict, dict):
numerical_dates = {k:float(v) for k,v in meta_dict.items()}
elif isinstance(meta_dict, pd.DataFrame):
strains = meta_dict.index.values
dates = meta_dict[date_col].astype(float)
numerical_dates = dict(zip(strains, dates))

return numerical_dates

Expand Down
13 changes: 8 additions & 5 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from Bio.SeqRecord import SeqRecord

import augur.filter
from augur.utils import read_metadata

@pytest.fixture
def argparser():
Expand Down Expand Up @@ -163,17 +164,19 @@ def test_filter_on_query_good(self, tmpdir, sequences):
("SEQ_1","colorado","good"),
("SEQ_2","colorado","bad"),
("SEQ_3","nevada","good")))
filtered = augur.filter.filter_by_query(sequences.keys(), meta_fn, 'quality=="good"')
assert filtered == ["SEQ_1", "SEQ_3"]
metadata, columns = read_metadata(meta_fn, as_data_frame=True)
filtered = augur.filter.filter_by_query(set(sequences.keys()), metadata, 'quality=="good"')
assert sorted(filtered) == ["SEQ_1", "SEQ_3"]

def test_filter_on_query_subset(self, tmpdir):
"""Test filtering on query works when given fewer strains than metadata"""
meta_fn = write_metadata(tmpdir, (("strain","location","quality"),
("SEQ_1","colorado","good"),
("SEQ_2","colorado","bad"),
("SEQ_3","nevada","good")))
filtered = augur.filter.filter_by_query(["SEQ_2"], meta_fn, 'quality=="bad" & location=="colorado"')
assert filtered == ["SEQ_2"]
metadata, columns = read_metadata(meta_fn, as_data_frame=True)
filtered = augur.filter.filter_by_query({"SEQ_2"}, metadata, 'quality=="bad" & location=="colorado"')
assert sorted(filtered) == ["SEQ_2"]

def test_filter_run_with_query(self, tmpdir, fasta_fn, argparser):
"""Test that filter --query works as expected"""
Expand Down Expand Up @@ -242,4 +245,4 @@ def test_filter_run_max_date(self, tmpdir, fasta_fn, argparser):
% (fasta_fn, meta_fn, out_fn, max_date))
augur.filter.run(args)
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
assert list(output.keys()) == ["SEQ_1", "SEQ_2"]
assert list(output.keys()) == ["SEQ_1", "SEQ_2"]