Skip to content

Commit

Permalink
Use DataFrame for metadata instead of dict
Browse files Browse the repository at this point in the history
Allows metadata to be returned as a pandas DataFrame instead of a
dictionary and uses this API to load metadata as a DataFrame in augur
filter. Updates the corresponding metadata references in augur filter
and related function calls to work with this new data structure.

This change allows us to eventually switch to chunking metadata input to
avoid reading it all into memory.
  • Loading branch information
huddlej committed Jul 15, 2021
1 parent 89bc028 commit 43e6ded
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 66 deletions.
78 changes: 39 additions & 39 deletions augur/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,25 @@ def read_priority_scores(fname):
print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr)
raise e

def filter_by_query(sequences, metadata_file, query):
"""Filter a set of sequences using Pandas DataFrame querying against the metadata file.
def filter_by_query(strains, metadata, query):
"""Filter a set of strains using Pandas DataFrame querying against the given metadata.
Parameters
----------
sequences : list[str]
List of sequence names to filter
metadata_file : str
Path to the metadata associated with the sequences
strains : set[str]
Set of strain names to filter
metadata : pandas.DataFrame
Metadata associated with the given strains
query : str
Query string for the dataframe.
Returns
-------
list[str]:
List of sequence names that match the given query
set[str]:
Set of strains that match the given query
"""
filtered_meta_dict, _ = read_metadata(metadata_file, query)
return [seq for seq in sequences if seq in filtered_meta_dict]
filtered_strains = set(metadata.query(query).index.values)
return strains & filtered_strains

def register_arguments(parser):
input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered")
Expand Down Expand Up @@ -177,8 +177,8 @@ def run(args):
try:
# Metadata are the source of truth for which sequences we want to keep
# in filtered output.
meta_dict, meta_columns = read_metadata(args.metadata)
metadata_strains = set(meta_dict.keys())
metadata, meta_columns = read_metadata(args.metadata, as_data_frame=True)
metadata_strains = set(metadata.index.values)
except ValueError as error:
print("ERROR: Problem reading in {}:".format(args.metadata))
print(error)
Expand Down Expand Up @@ -297,10 +297,10 @@ def run(args):
to_exclude = set()
for seq_name in seq_keep:
if "!=" in ex: # i.e. property!=value requested
if meta_dict[seq_name].get(col,'unknown').lower() != val.lower():
if metadata.loc[seq_name].get(col,'unknown').lower() != val.lower():
to_exclude.add(seq_name)
else: # i.e. property=value requested
if meta_dict[seq_name].get(col,'unknown').lower() == val.lower():
if metadata.loc[seq_name].get(col,'unknown').lower() == val.lower():
to_exclude.add(seq_name)

num_excluded_by_metadata[ex] = len(seq_keep & to_exclude)
Expand All @@ -309,34 +309,16 @@ def run(args):
# exclude strains by metadata, using Pandas querying
num_excluded_by_query = 0
if args.query:
filtered = set(filter_by_query(list(seq_keep), args.metadata, args.query))
filtered = filter_by_query(seq_keep, metadata, args.query)
num_excluded_by_query = len(seq_keep - filtered)
seq_keep = filtered

# filter by sequence length
num_excluded_by_length = 0
if args.min_length:
if is_vcf: #doesn't make sense for VCF, ignore.
print("WARNING: Cannot use min_length for VCF files. Ignoring...")
else:
is_in_seq_keep = sequence_index["strain"].isin(seq_keep)
is_gte_min_length = sequence_index["ACGT"] >= args.min_length

seq_keep_by_length = set(
sequence_index[
(is_in_seq_keep) & (is_gte_min_length)
]["strain"].tolist()
)

num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length)
seq_keep = seq_keep_by_length

# filter by ambiguous dates
num_excluded_by_ambiguous_date = 0
if args.exclude_ambiguous_dates_by and 'date' in meta_columns:
seq_keep_by_date = set()
for seq_name in seq_keep:
if not is_date_ambiguous(meta_dict[seq_name]['date'], args.exclude_ambiguous_dates_by):
if not is_date_ambiguous(metadata.loc[seq_name]['date'], args.exclude_ambiguous_dates_by):
seq_keep_by_date.add(seq_name)

num_excluded_by_ambiguous_date = len(seq_keep) - len(seq_keep_by_date)
Expand All @@ -345,7 +327,7 @@ def run(args):
# filter by date
num_excluded_by_date = 0
if (args.min_date or args.max_date) and 'date' in meta_columns:
dates = get_numerical_dates(meta_dict, fmt="%Y-%m-%d")
dates = get_numerical_dates(metadata, fmt="%Y-%m-%d")
tmp = {s for s in seq_keep if dates[s] is not None}
if args.min_date:
tmp = {s for s in tmp if (np.isscalar(dates[s]) or all(dates[s])) and np.max(dates[s])>=args.min_date}
Expand All @@ -354,6 +336,24 @@ def run(args):
num_excluded_by_date = len(seq_keep) - len(tmp)
seq_keep = tmp

# filter by sequence length
num_excluded_by_length = 0
if args.min_length:
if is_vcf: #doesn't make sense for VCF, ignore.
print("WARNING: Cannot use min_length for VCF files. Ignoring...")
else:
is_in_seq_keep = sequence_index["strain"].isin(seq_keep)
is_gte_min_length = sequence_index["ACGT"] >= args.min_length

seq_keep_by_length = set(
sequence_index[
(is_in_seq_keep) & (is_gte_min_length)
]["strain"].tolist()
)

num_excluded_by_length = len(seq_keep) - len(seq_keep_by_length)
seq_keep = seq_keep_by_length

# exclude sequences with non-nucleotide characters
num_excluded_by_nuc = 0
if args.non_nucleotide:
Expand Down Expand Up @@ -385,7 +385,7 @@ def run(args):
probabilistic_sampling = False

if args.subsample_max_sequences or (args.group_by and args.sequences_per_group):

#set groups to group_by values
if args.group_by:
groups = args.group_by
Expand All @@ -398,7 +398,7 @@ def run(args):

for seq_name in seq_keep:
group = []
m = meta_dict[seq_name]
m = metadata.loc[seq_name].to_dict()
# collect group specifiers
for c in groups:
if c == "_dummy":
Expand Down Expand Up @@ -551,7 +551,7 @@ def run(args):

# loop over all sequences and re-add sequences
for seq_name in available_strains:
if meta_dict[seq_name].get(col)==val:
if metadata.loc[seq_name].get(col)==val:
to_include.add(seq_name)

num_included_by_metadata = len(to_include)
Expand Down Expand Up @@ -608,7 +608,7 @@ def run(args):
num_excluded_by_lack_of_sequences = len(metadata_strains - sequence_strains)

if args.output_metadata:
metadata_df = pd.DataFrame([meta_dict[strain] for strain in seq_keep])
metadata_df = metadata.loc[seq_keep]
metadata_df.to_csv(
args.output_metadata,
sep="\t",
Expand Down
11 changes: 8 additions & 3 deletions augur/util_support/metadata_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ class MetadataFile:
which is used to match metadata with samples.
"""

def __init__(self, fname, query=None):
def __init__(self, fname, query=None, as_data_frame=False):
self.fname = fname
self.query = query
self.as_data_frame = as_data_frame

self.key_type = self.find_key_type()

Expand All @@ -26,8 +27,12 @@ def read(self):
# original "strain"/"name" remains in the output.
self.metadata["_index"] = self.metadata[self.key_type]

metadata_dict = self.metadata.set_index("_index").to_dict("index")
return metadata_dict, self.columns
metadata = self.metadata.set_index("_index")

if self.as_data_frame:
return metadata, self.columns
else:
return metadata.to_dict("index"), self.columns

@property
@functools.lru_cache()
Expand Down
74 changes: 55 additions & 19 deletions augur/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import Bio
import Bio.Phylo
from datetime import datetime
import gzip
import os, json, sys
import pandas as pd
Expand Down Expand Up @@ -59,8 +60,8 @@ def get_json_name(args, default=None):
def ambiguous_date_to_date_range(uncertain_date, fmt, min_max_year=None):
return DateDisambiguator(uncertain_date, fmt=fmt, min_max_year=min_max_year).range()

def read_metadata(fname, query=None):
return MetadataFile(fname, query).read()
def read_metadata(fname, query=None, as_data_frame=False):
return MetadataFile(fname, query, as_data_frame).read()

def is_date_ambiguous(date, ambiguous_by="any"):
"""
Expand Down Expand Up @@ -93,28 +94,63 @@ def is_date_ambiguous(date, ambiguous_by="any"):
"X" in day and ambiguous_by in ("any", "day")
))

def get_numerical_date_from_value(value, fmt=None, min_max_year=None, raise_error=True):
if type(value)!=str:
if raise_error:
raise ValueError(value)
else:
numerical_date = None
elif 'XX' in value:
ambig_date = ambiguous_date_to_date_range(value, fmt, min_max_year)
if ambig_date is None or None in ambig_date:
numerical_date = [None, None] #don't send to numeric_date or will be set to today
else:
numerical_date = [numeric_date(d) for d in ambig_date]
else:
try:
numerical_date = numeric_date(datetime.strptime(value, fmt))
except:
numerical_date = None

return numerical_date

def get_numerical_dates(meta_dict, name_col = None, date_col='date', fmt=None, min_max_year=None):
if fmt:
from datetime import datetime
numerical_dates = {}
for k,m in meta_dict.items():
v = m[date_col]
if type(v)!=str:
print("WARNING: %s has an invalid data string:"%k,v)
continue
elif 'XX' in v:
ambig_date = ambiguous_date_to_date_range(v, fmt, min_max_year)
if ambig_date is None or None in ambig_date:
numerical_dates[k] = [None, None] #don't send to numeric_date or will be set to today
else:
numerical_dates[k] = [numeric_date(d) for d in ambig_date]
else:

if isinstance(meta_dict, dict):
for k,m in meta_dict.items():
v = m[date_col]
try:
numerical_dates[k] = numeric_date(datetime.strptime(v, fmt))
except:
numerical_dates[k] = None
numerical_dates[k] = get_numerical_date_from_value(
v,
fmt,
min_max_year
)
except ValueError:
print(
"WARNING: %s has an invalid data string: %s"% (k, v),
file=sys.stderr
)
continue
elif isinstance(meta_dict, pd.DataFrame):
strains = meta_dict.index.values
dates = meta_dict[date_col].apply(
lambda date: get_numerical_date_from_value(
date,
fmt,
min_max_year,
raise_error=False
)
).values
numerical_dates = dict(zip(strains, dates))
else:
numerical_dates = {k:float(v) for k,v in meta_dict.items()}
if isinstance(meta_dict, dict):
numerical_dates = {k:float(v) for k,v in meta_dict.items()}
elif isinstance(meta_dict, pd.DataFrame):
strains = meta_dict.index.values
dates = meta_dict[date_col].astype(float)
numerical_dates = dict(zip(strains, dates))

return numerical_dates

Expand Down
13 changes: 8 additions & 5 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from Bio.SeqRecord import SeqRecord

import augur.filter
from augur.utils import read_metadata

@pytest.fixture
def argparser():
Expand Down Expand Up @@ -163,17 +164,19 @@ def test_filter_on_query_good(self, tmpdir, sequences):
("SEQ_1","colorado","good"),
("SEQ_2","colorado","bad"),
("SEQ_3","nevada","good")))
filtered = augur.filter.filter_by_query(sequences.keys(), meta_fn, 'quality=="good"')
assert filtered == ["SEQ_1", "SEQ_3"]
metadata, columns = read_metadata(meta_fn, as_data_frame=True)
filtered = augur.filter.filter_by_query(set(sequences.keys()), metadata, 'quality=="good"')
assert sorted(filtered) == ["SEQ_1", "SEQ_3"]

def test_filter_on_query_subset(self, tmpdir):
"""Test filtering on query works when given fewer strains than metadata"""
meta_fn = write_metadata(tmpdir, (("strain","location","quality"),
("SEQ_1","colorado","good"),
("SEQ_2","colorado","bad"),
("SEQ_3","nevada","good")))
filtered = augur.filter.filter_by_query(["SEQ_2"], meta_fn, 'quality=="bad" & location=="colorado"')
assert filtered == ["SEQ_2"]
metadata, columns = read_metadata(meta_fn, as_data_frame=True)
filtered = augur.filter.filter_by_query({"SEQ_2"}, metadata, 'quality=="bad" & location=="colorado"')
assert sorted(filtered) == ["SEQ_2"]

def test_filter_run_with_query(self, tmpdir, fasta_fn, argparser):
"""Test that filter --query works as expected"""
Expand Down Expand Up @@ -242,4 +245,4 @@ def test_filter_run_max_date(self, tmpdir, fasta_fn, argparser):
% (fasta_fn, meta_fn, out_fn, max_date))
augur.filter.run(args)
output = SeqIO.to_dict(SeqIO.parse(out_fn, "fasta"))
assert list(output.keys()) == ["SEQ_1", "SEQ_2"]
assert list(output.keys()) == ["SEQ_1", "SEQ_2"]

0 comments on commit 43e6ded

Please sign in to comment.