Skip to content

Commit

Permalink
#1818, sped up TRAPI generation step
Browse files Browse the repository at this point in the history
  • Loading branch information
finnagin committed May 30, 2022
1 parent ee166ad commit 5ca3d9a
Showing 1 changed file with 102 additions and 48 deletions.
150 changes: 102 additions & 48 deletions code/ARAX/ARAXQuery/ARAX_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
from ARAX_messenger import ARAXMessenger
from ARAX_expander import ARAXExpander
from ARAX_resultify import ARAXResultify
from ARAX_decorator import ARAXDecorator
import traceback
from collections import Counter
from collections.abc import Hashable
Expand All @@ -25,6 +26,10 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
from openapi_server.models.q_node import QNode
from openapi_server.models.attribute import Attribute as EdgeAttribute
from openapi_server.models.edge import Edge
from openapi_server.models.node import Node

sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'NodeSynonymizer']))
from node_synonymizer import NodeSynonymizer

sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery', 'Infer', 'scripts']))
# from creativeDTD import creativeDTD
Expand All @@ -33,6 +38,7 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
import pandas as pd



class ARAXInfer:

#### Constructor
Expand Down Expand Up @@ -95,6 +101,9 @@ def __init__(self):
}


def __get_formated_edge_key(self, edge: Edge, kp: str = 'infores:rtx-kg2') -> str:
    """
    Build the knowledge-graph key string for an edge.

    The key has the form ``<kp>:<subject>-<predicate>-<object>``, where the
    three triple components are read from the given edge.

    :param edge: edge whose ``subject``, ``predicate`` and ``object`` attributes make up the key
    :param kp: prefix placed before the triple (defaults to 'infores:rtx-kg2')
    :return: the formatted edge key
    """
    # Assemble the subject-predicate-object triple first, then prepend the prefix.
    triple = f"{edge.subject}-{edge.predicate}-{edge.object}"
    return f"{kp}:{triple}"

def report_response_stats(self, response):
"""
Little helper function that will report the KG, QG, and results stats to the debug in the process of executing actions. Basically to help diagnose problems
Expand Down Expand Up @@ -277,6 +286,8 @@ def __drug_treatment_graph_expansion(self, describe=False):
# FW: may need these to add the answer graphs; if not, will delete
expander = ARAXExpander()
messenger = ARAXMessenger()
synonymizer = NodeSynonymizer()
decorator = ARAXDecorator()

# expand parameters
mode = 'ARAX'
Expand All @@ -292,18 +303,26 @@ def __drug_treatment_graph_expansion(self, describe=False):
# top_paths = dtd.predict_top_M_paths(self.parameters['n_paths'])

# FW: temp fix to use the pickle file for dev work rather than recomputing

top_drugs = pd.read_csv(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery', 'Infer', 'data',"top_n_drugs.csv"]))
with open(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery', 'Infer', 'data',"result_from_self_predict_top_M_paths.pkl"]),"rb") as fid:
top_paths = pickle.load(fid)


node_names = set([y for paths in top_paths.values() for x in paths for y in x[0].split("->")[::2] if y != ''])
node_info = synonymizer.get_canonical_curies(names=list(node_names))
node_name_to_id = {k:v['preferred_curie'] for k,v in node_info.items() if v is not None}
path_lengths = set([math.floor(len(x[0].split("->"))/2.) for paths in top_paths.values() for x in paths])
max_path_len = max(path_lengths)
disease = list(top_paths.keys())[0][1]
disease_name = list(top_paths.values())[0][0][0].split("->")[-1]
add_qnode_params = {
'key' : "disease",
'name': disease
}
self.response = messenger.add_qnode(self.response, add_qnode_params)
self.response.envelope.message.knowledge_graph.nodes[disease] = Node(name=disease_name, categories=['biolink:DiseaseOrPhenotypicFeature'])
self.response.envelope.message.knowledge_graph.nodes[disease].qnode_keys = ['disease']
node_name_to_id[disease_name] = disease
add_qnode_params = {
'key' : "drug",
'categories': ['biolink:Drug']
Expand Down Expand Up @@ -355,73 +374,108 @@ def __drug_treatment_graph_expansion(self, describe=False):
# The x[0] is here since each element consists of the string path and a score; we are currently ignoring the score
split_paths = [x[0].split("->") for x in paths]
for path in split_paths:
new_response = ARAXResponse()
messenger.create_envelope(new_response)
drug_name = path[0]
if any([x not in node_name_to_id for x in path[::2]]):
continue
# new_response = ARAXResponse()
# messenger.create_envelope(new_response)
n_elements = len(path)
# Creates edge tuples of the form (node name 1, edge predicate, node name 2)
edge_tuples = [(path[i],path[i+1],path[i+2]) for i in range(0,n_elements-2,2)]
path_idx = len(edge_tuples)-1
added_nodes = set()
for i in range(path_idx+1):
if path_keys[path_idx]["qnode_pairs"][i][0] not in added_nodes:
add_qnode_params = {
'key' : path_keys[path_idx]["qnode_pairs"][i][0],
'name': edge_tuples[i][0]
}
new_response = messenger.add_qnode(new_response, add_qnode_params)
added_nodes.add(path_keys[path_idx]["qnode_pairs"][i][0])
if path_keys[path_idx]["qnode_pairs"][i][1] not in added_nodes:
add_qnode_params = {
'key' : path_keys[path_idx]["qnode_pairs"][i][1],
'name': edge_tuples[i][2]
}
new_response = messenger.add_qnode(new_response, add_qnode_params)
added_nodes.add(path_keys[path_idx]["qnode_pairs"][i][1])
new_qedge_key = path_keys[path_idx]["qedge_keys"][i]
add_qedge_params = {
'key' : new_qedge_key,
'subject' : path_keys[path_idx]["qnode_pairs"][i][0],
'object' : path_keys[path_idx]["qnode_pairs"][i][1],
'predicates': [edge_tuples[i][1]]
}
new_response = messenger.add_qedge(new_response, add_qedge_params)
expand_params = {
'kp':kp,
'prune_threshold':prune_threshold,
'edge_key':path_keys[path_idx]["qedge_keys"],
'kp_timeout':timeout
}
new_response = expander.apply(new_response, expand_params, mode=mode)
if new_response.status == 'OK':
for knode_id, knode in new_response.envelope.message.knowledge_graph.nodes.items():
if 'disease' in knode.qnode_keys:
normalized_disease = knode_id
if 'drug' in knode.qnode_keys:
normalized_drug = knode_id
normalized_drug_name = knode.name
if knode_id in self.response.envelope.message.knowledge_graph.nodes:
new_response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys += self.response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys
self.response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys = new_response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys
self.response.envelope.message.knowledge_graph.nodes.update(new_response.envelope.message.knowledge_graph.nodes)
self.response.envelope.message.knowledge_graph.edges.update(new_response.envelope.message.knowledge_graph.edges)
self.response.merge(new_response)
path_added = True
# if path_keys[path_idx]["qnode_pairs"][i][0] not in added_nodes:
# add_qnode_params = {
# 'key' : path_keys[path_idx]["qnode_pairs"][i][0],
# 'name': edge_tuples[i][0]
# }
# new_response = messenger.add_qnode(new_response, add_qnode_params)
# added_nodes.add(path_keys[path_idx]["qnode_pairs"][i][0])
subject_qnode_key = path_keys[path_idx]["qnode_pairs"][i][0]
subject_name = edge_tuples[i][0]
subject_curie = node_name_to_id[subject_name]
subject_category = node_info[subject_name]['preferred_category']
if subject_curie not in self.response.envelope.message.knowledge_graph.nodes:
self.response.envelope.message.knowledge_graph.nodes[subject_curie] = Node(name=subject_name, categories=[subject_category])
self.response.envelope.message.knowledge_graph.nodes[subject_curie].qnode_keys = [subject_qnode_key]
elif subject_qnode_key not in self.response.envelope.message.knowledge_graph.nodes[subject_curie].qnode_keys:
self.response.envelope.message.knowledge_graph.nodes[subject_curie].qnode_keys.append(subject_qnode_key)
# if path_keys[path_idx]["qnode_pairs"][i][1] not in added_nodes:
# add_qnode_params = {
# 'key' : path_keys[path_idx]["qnode_pairs"][i][1],
# 'name': edge_tuples[i][2]
# }
# new_response = messenger.add_qnode(new_response, add_qnode_params)
# added_nodes.add(path_keys[path_idx]["qnode_pairs"][i][1])
object_qnode_key = path_keys[path_idx]["qnode_pairs"][i][1]
object_name = edge_tuples[i][2]
object_curie = node_name_to_id[object_name]
object_category = node_info[object_name]['preferred_category']
if object_curie not in self.response.envelope.message.knowledge_graph.nodes:
self.response.envelope.message.knowledge_graph.nodes[object_curie] = Node(name=object_name, categories=[object_category])
self.response.envelope.message.knowledge_graph.nodes[object_curie].qnode_keys = [object_qnode_key]
elif object_qnode_key not in self.response.envelope.message.knowledge_graph.nodes[object_curie].qnode_keys:
self.response.envelope.message.knowledge_graph.nodes[object_curie].qnode_keys.append(object_qnode_key)
# new_qedge_key = path_keys[path_idx]["qedge_keys"][i]
# add_qedge_params = {
# 'key' : new_qedge_key,
# 'subject' : path_keys[path_idx]["qnode_pairs"][i][0],
# 'object' : path_keys[path_idx]["qnode_pairs"][i][1],
# 'predicates': [edge_tuples[i][1]]
# }
# new_response = messenger.add_qedge(new_response, add_qedge_params)
new_edge = Edge(subject=subject_curie, object=object_curie, predicate=edge_tuples[i][1], attributes=[])
new_edge.attributes.append(EdgeAttribute(attribute_type_id="biolink:aggregator_knowledge_source",
value=kp,
value_type_id="biolink:InformationResource",
attribute_source=kp))
new_edge_key = self.__get_formated_edge_key(edge=new_edge, kp=kp)
self.response.envelope.message.knowledge_graph.edges[new_edge_key] = new_edge
self.response.envelope.message.knowledge_graph.edges[new_edge_key].qedge_keys = [path_keys[path_idx]["qedge_keys"][i]]
# expand_params = {
# 'kp':kp,
# 'prune_threshold':prune_threshold,
# 'edge_key':path_keys[path_idx]["qedge_keys"],
# 'kp_timeout':timeout
# }
# new_response = expander.apply(new_response, expand_params, mode=mode)
# if new_response.status == 'OK':
# for knode_id, knode in new_response.envelope.message.knowledge_graph.nodes.items():
# if 'disease' in knode.qnode_keys:
# normalized_disease = knode_id
# if 'drug' in knode.qnode_keys:
# normalized_drug = knode_id
# normalized_drug_name = knode.name
# if knode_id in self.response.envelope.message.knowledge_graph.nodes:
# new_response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys += self.response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys
# self.response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys = new_response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys
# self.response.envelope.message.knowledge_graph.nodes.update(new_response.envelope.message.knowledge_graph.nodes)
# self.response.envelope.message.knowledge_graph.edges.update(new_response.envelope.message.knowledge_graph.edges)
# self.response.merge(new_response)
path_added = True
if path_added:
treat_score = top_drugs.loc[top_drugs['drug_id'] == drug]["tp_score"].iloc[0]
essence_scores[normalized_drug_name] = treat_score
essence_scores[drug_name] = treat_score
edge_attribute_list = [
# EdgeAttribute(original_attribute_name="defined_datetime", value=defined_datetime, attribute_type_id="metatype:Datetime"),
EdgeAttribute(original_attribute_name="provided_by", value="infores:arax", attribute_type_id="biolink:aggregator_knowledge_source", attribute_source="infores:arax", value_type_id="biolink:InformationResource"),
EdgeAttribute(original_attribute_name=None, value=True, attribute_type_id="biolink:computed_value", attribute_source="infores:arax-reasoner-ara", value_type_id="metatype:Boolean", value_url=None, description="This edge is a container for a computed value between two nodes that is not directly attachable to other edges."),
EdgeAttribute(attribute_type_id="EDAM:data_0951", original_attribute_name="probability_treats", value=str(treat_score))
]
fixed_edge = Edge(predicate="biolink:probably_treats", subject=normalized_drug, object=normalized_disease,
fixed_edge = Edge(predicate="biolink:probably_treats", subject=node_name_to_id[drug_name], object=node_name_to_id[disease_name],
attributes=edge_attribute_list)
fixed_edge.qedge_keys = ["probably_treats"]
self.response.envelope.message.knowledge_graph.edges[f"creative_DTD_prediction_{self.kedge_global_iter}"] = fixed_edge
self.kedge_global_iter += 1
else:
self.response.warning(f"Something went wrong when adding the subgraph for the drug-disease pair ({drug},{disease}) to the knowledge graph. Skipping this result....")
self.response = decorator.decorate_nodes(self.response)
if self.response.status != 'OK':
return self.response
self.response = decorator.decorate_edges(self.response)
if self.response.status != 'OK':
return self.response
resultifier = ARAXResultify()
resultify_params = {
"ignore_edge_direction": "true"
Expand Down

0 comments on commit 5ca3d9a

Please sign in to comment.