diff --git a/batch.py b/batch.py
index c5d720d..178790e 100644
--- a/batch.py
+++ b/batch.py
@@ -10,7 +10,8 @@
 from covizu.utils.batch_utils import *
 from covizu.utils.seq_utils import SC2Locator
 from tempfile import NamedTemporaryFile
-
+import psycopg2
+import psycopg2.extras
 
 def parse_args():
     parser = argparse.ArgumentParser(description="CoVizu analysis pipeline automation")
@@ -102,13 +103,33 @@ def parse_args():
     return parser.parse_args()
 
 
-def process_feed(args, callback=None):
+def open_connection():
+    """ open connection to database, initialize tables if they don't exist
+    :out:
+        :cursor: interactive sql object containing tables
+    """
+    conn = psycopg2.connect(host="localhost", dbname="gsaid_db", user="postgres",
+                            password="12345", port="5432")
+    cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
+
+    # create tables if they don't exist
+    seqs_table = '''CREATE TABLE IF NOT EXISTS SEQUENCES (accession VARCHAR(255)
+                    PRIMARY KEY, qname VARCHAR(255), lineage VARCHAR(255),
+                    date VARCHAR(255), location VARCHAR(255),
+                    diffs VARCHAR, missing VARCHAR)'''
+    cur.execute(seqs_table)
+
+    conn.commit()
+    return cur, conn
+
+
+def process_feed(args, cur, callback=None):
     """ Process feed data """
     if callback:
         callback("Processing GISAID feed data")
     loader = gisaid_utils.load_gisaid(args.infile, minlen=args.minlen, mindate=args.mindate)
-    batcher = gisaid_utils.batch_fasta(loader, size=args.batchsize)
-    aligned = gisaid_utils.extract_features(batcher, ref_file=args.ref, binpath=args.mmbin,
+    batcher = gisaid_utils.batch_fasta(loader, cur, size=args.batchsize)
+    aligned = gisaid_utils.extract_features(batcher, ref_file=args.ref, cur=cur, binpath=args.mmbin,
                                             nthread=args.mmthreads, minlen=args.minlen)
     filtered = gisaid_utils.filter_problematic(aligned, vcf_file=args.vcf, cutoff=args.poisson_cutoff,
                                                callback=callback)
@@ -119,6 +140,8 @@ def process_feed(args, callback=None):
     args = parse_args()
     cb = Callback()
 
+    cur, conn = open_connection()
+
     # check that user has loaded openmpi module
     try:
         subprocess.check_call(['mpirun', '-np', '2', 'ls'], stdout=subprocess.DEVNULL)
@@ -147,7 +170,7 @@ def process_feed(args, callback=None):
         args.infile = gisaid_utils.download_feed(args.url, args.user, args.password)
 
     # filter data, align genomes, extract features, sort by lineage
-    by_lineage = process_feed(args, cb.callback)
+    by_lineage = process_feed(args, cur, cb.callback)
 
     # reconstruct time-scaled tree relating lineages
     timetree, residuals = build_timetree(by_lineage, args, cb.callback)
@@ -222,4 +245,5 @@ def process_feed(args, callback=None):
         fp.close()
         subprocess.check_call(['scp', fp.name, '{}/clusters.json'.format(server_epicov)])
 
+    conn.commit()
     cb.callback("All done!")
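The open_connection() added above fixes the cursor contract for the rest of the pipeline: cursor_factory=psycopg2.extras.RealDictCursor makes every fetch return rows as dicts keyed by column name, which is what the reworked batch_fasta() in the second file relies on when it reads result["diffs"] and result["missing"]. A minimal sketch of that contract, assuming the hardcoded localhost/gsaid_db credentials from the patch and a populated SEQUENCES table:

import psycopg2
import psycopg2.extras

# connection settings are the hardcoded values from open_connection() above
conn = psycopg2.connect(host="localhost", dbname="gsaid_db", user="postgres",
                        password="12345", port="5432")
cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

# RealDictCursor returns each row as a dict keyed by column name
cur.execute("SELECT accession, qname FROM SEQUENCES LIMIT 1")
row = cur.fetchone()   # e.g. {'accession': 'EPI_ISL_...', 'qname': '...'}, or None
conn.close()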
dbname="gsaid_db", user="postgres", - password="12345", port="5432") - cur = conn.cursor() - - # create tables if they don't exist - seqs_table = '''CREATE TABLE IF NOT EXISTS SEQUENCES (accession VARCHAR(255) - PRIMARY KEY, name VARCHAR(255), lineage VARCHAR(255), - date VARCHAR(255), location VARCHAR(255))''' - cur.execute(seqs_table) - - fvecs_table = '''CREATE TABLE IF NOT EXISTS FEATURES (accession VARCHAR(255), - name VARCHAR(255), lineage VARCHAR(255), date VARCHAR(255), - location VARCHAR(255), vectors VARCHAR(1000))''' - cur.execute(fvecs_table) - - # create index on vectors column - cur.execute('''CREATE INDEX IF NOT EXISTS FVECS_INDEX ON FEATURES (vectors)''') - - conn.commit() - return cur, conn - - def download_feed(url, user, password): """ Download xz file from GISAID. Note this requires confidential URL, user and password @@ -76,7 +40,6 @@ def download_feed(url, user, password): def load_gisaid(path, minlen=29000, mindate='2019-12-01', callback=None, - database='covizu/data/gsaid.db', fields=("covv_accession_id", "covv_virus_name", "covv_lineage", "covv_collection_date", "covv_location", "sequence"), debug=None @@ -93,10 +56,6 @@ def load_gisaid(path, minlen=29000, mindate='2019-12-01', callback=None, :yield: dict, contents of each GISAID record """ - - # initialize database objects - cur, conn = open_connection(database) - mindate = fromisoformat(mindate) rejects = {'short': 0, 'baddate': 0, 'nonhuman': 0, 'nolineage': 0} with lzma.open(path, 'rb') as handle: @@ -108,19 +67,6 @@ def load_gisaid(path, minlen=29000, mindate='2019-12-01', callback=None, # remove unused data record = dict([(k, record[k]) for k in fields]) - # data = cur.execute("SELECT accession FROM SEQUENCES WHERE accession = ?", - # (record["covv_accession_id"],)).fetchone() - cur.execute("SELECT accession FROM SEQUENCES WHERE accession = '%s'"%(record["covv_accession_id"],)) - data = cur.fetchone() - - if data == None: - # cur.execute("INSERT INTO SEQUENCES (accession, name, lineage, date, location) VALUES(?, ?, ?, ?, ?)", - # [v for k, v in record.items() if k != 'sequence']) - cur.execute("INSERT INTO SEQUENCES (accession, name, lineage, date, location) VALUES(%s, %s, %s, %s, %s)", - [v for k, v in record.items() if k != 'sequence']) - else: - continue - qname = record['covv_virus_name'].strip().replace(',', '_').replace('|', '_') # issue #206,#464 country = qname.split('/')[1] if country == '' or country[0].islower(): @@ -149,9 +95,6 @@ def load_gisaid(path, minlen=29000, mindate='2019-12-01', callback=None, yield record - conn.commit() - conn.close() - if callback: callback("Rejected {short} short genomes\n" " {nolineage} with no lineage assignment\n" @@ -159,7 +102,7 @@ def load_gisaid(path, minlen=29000, mindate='2019-12-01', callback=None, " {nonhuman} non-human genomes".format(**rejects)) -def batch_fasta(gen, size=100): +def batch_fasta(gen, cur, size=100): """ Concatenate sequence records in stream into FASTA-formatted text in batches of records. 
@@ -169,10 +112,17 @@ def batch_fasta(gen, size=100):
     """
     stdin = ''
     batch = []
+
     for i, record in enumerate(gen, 1):
         qname = record['covv_virus_name']
         sequence = record.pop('sequence')
-        stdin += '>{}\n{}\n'.format(qname, sequence)
+
+        cur.execute("SELECT * FROM SEQUENCES WHERE qname = %s", (qname,))
+        result = cur.fetchone()
+        if result:
+            record.update({'diffs': result["diffs"], 'missing': result["missing"]})
+        else:
+            stdin += '>{}\n{}\n'.format(qname, sequence)
         batch.append(record)
         if i > 0 and i % size == 0:
             yield stdin, batch
@@ -183,7 +133,7 @@ def batch_fasta(gen, size=100):
         yield stdin, batch
 
 
-def extract_features(batcher, ref_file, binpath='minimap2', nthread=3, minlen=29000):
+def extract_features(batcher, ref_file, cur, binpath='minimap2', nthread=3, minlen=29000):
     """
     Stream output from JSON.xz file via load_gisaid() into minimap2
    via subprocess.
@@ -200,13 +150,28 @@ def extract_features(batcher, ref_file, binpath='minimap2', nthread=3, minlen=29
         reflen = len(convert_fasta(handle)[0][1])
 
     for fasta, batch in batcher:
+        new_records = {}
+        for record in batch:
+            if 'diffs' in record:
+                yield record
+            else:
+                new_records[record['covv_virus_name']] = record
+
+        # If fasta is empty, no need to run minimap2
+        if len(fasta) == 0:
+            continue
+
         mm2 = minimap2.minimap2(fasta, ref_file, stream=True, path=binpath, nthread=nthread,
                                 minlen=minlen)
         result = list(minimap2.encode_diffs(mm2, reflen=reflen))
-        for row, record in zip(result, batch):
+        for qname, diffs, missing in result:
             # reconcile minimap2 output with GISAID record
-            qname, diffs, missing = row
+            record = new_records[qname]
             record.update({'diffs': diffs, 'missing': missing})
+
+            # inserting diffs and missing as json strings
+            cur.execute("INSERT INTO SEQUENCES (accession, qname, lineage, date, location, diffs, missing) VALUES(%s, %s, %s, %s, %s, %s, %s)",
+                        [json.dumps(v) if k in ['diffs', 'missing'] else v for k, v in record.items()])
             yield record
@@ -243,7 +208,15 @@ def filter_problematic(records, origin='2019-12-01', rate=0.0655, cutoff=0.005,
         if type(record) is not dict:
             qname, diffs, missing = record  # unpack tuple
         else:
-            diffs = record['diffs']
+            try:
+                # loading json strings of old records
+                record['diffs'] = json.loads(record['diffs'])
+                record['missing'] = json.loads(record['missing'])
+            except TypeError:
+                # passing for new records
+                pass
+            finally:
+                diffs = record['diffs']
 
         # exclude problematic sites
         filtered = []
@@ -296,9 +269,6 @@ def sort_by_lineage(records, callback=None, database='covizu/data/gsaid.db', int
     """
     result = {}
 
-    # initialize database objects
-    cur, conn = open_connection(database)
-
     for i, record in enumerate(records):
         if callback and i % interval == 0:
             callback('aligned {} records'.format(i))
@@ -312,12 +282,6 @@ def sort_by_lineage(records, callback=None, database='covizu/data/gsaid.db', int
         if str(lineage) == "None" or lineage == '':
            # discard uncategorized genomes, #324, #335
            continue
-
-        # cur.execute("INSERT INTO FEATURES VALUES(?, ?, ?, ?, ?, ?)",
-        #             [v for k, v in record.items() if k != 'missing'] + [key])
-        cur.execute("INSERT INTO FEATURES VALUES(%s, %s, %s, %s, %s, %s)",
-                    [v for k, v in record.items() if k != 'missing'] + [key])
-        conn.commit()
 
         if lineage not in result:
             result.update({lineage: {}})
@@ -326,8 +290,6 @@ def sort_by_lineage(records, callback=None, database='covizu/data/gsaid.db', int
 
         result[lineage][key].append(record)
 
-    conn.close()
-
     return result
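Taken together, SEQUENCES now acts as an alignment cache keyed by accession: extract_features() inserts each newly aligned record with diffs/missing serialized as JSON strings, batch_fasta() looks records up by qname on later runs, and filter_problematic() decodes the stored strings. A self-contained sketch of that round trip — the record values here are hypothetical, the connection settings are the hardcoded ones from open_connection(), and the ON CONFLICT clause is an addition in this sketch (the patch itself relies on the PRIMARY KEY to surface duplicate inserts):

import json
import psycopg2
import psycopg2.extras

conn = psycopg2.connect(host="localhost", dbname="gsaid_db", user="postgres",
                        password="12345", port="5432")
cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

# hypothetical record, in the key order extract_features() relies on when it
# expands record.items() into the column list
record = {
    "covv_accession_id": "EPI_ISL_0000000",
    "covv_virus_name": "hCoV-19/Canada/example-1/2021",
    "covv_lineage": "B.1.1.7",
    "covv_collection_date": "2021-01-15",
    "covv_location": "North America / Canada",
    "diffs": [["~", 240, "T"]],            # feature tuples from encode_diffs()
    "missing": [[0, 54], [29800, 29903]],  # uncovered intervals
}

# write path (extract_features): diffs/missing stored as JSON strings
cur.execute("INSERT INTO SEQUENCES (accession, qname, lineage, date, location, diffs, missing) "
            "VALUES (%s, %s, %s, %s, %s, %s, %s) ON CONFLICT (accession) DO NOTHING",
            [json.dumps(v) if k in ('diffs', 'missing') else v for k, v in record.items()])
conn.commit()

# read path (batch_fasta): a cache hit means the genome skips minimap2 entirely
cur.execute("SELECT * FROM SEQUENCES WHERE qname = %s", (record["covv_virus_name"],))
row = cur.fetchone()
if row:
    # filter_problematic() applies the same json.loads() to cached records
    diffs = json.loads(row["diffs"])
    missing = json.loads(row["missing"])
conn.close()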