Skip to content

Commit

Permalink
#1818, breakout subgraph generation into infer utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
finnagin committed May 31, 2022
1 parent 352df51 commit e4083bf
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 191 deletions.
197 changes: 6 additions & 191 deletions code/ARAX/ARAXQuery/ARAX_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
from node_synonymizer import NodeSynonymizer

sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery', 'Infer', 'scripts']))
from infer_utilities import InferUtilities
# from creativeDTD import creativeDTD

import pickle
Expand Down Expand Up @@ -101,8 +102,8 @@ def __init__(self):
}


def __get_formated_edge_key(self, edge: Edge, kp: str = 'infores:rtx-kg2') -> str:
    """Build the knowledge-graph key string for an edge.

    The key has the form ``<kp>:<subject>-<predicate>-<object>``.

    :param edge: edge whose ``subject``, ``predicate`` and ``object``
        attributes are embedded in the key.
    :param kp: prefix placed before the colon; defaults to
        ``'infores:rtx-kg2'``.
    :return: the formatted key string.
    """
    key_parts = (kp, edge.subject, edge.predicate, edge.object)
    return "{}:{}-{}-{}".format(*key_parts)
# def __get_formated_edge_key(self, edge: Edge, kp: str = 'infores:rtx-kg2') -> str:
# return f"{kp}:{edge.subject}-{edge.predicate}-{edge.object}"

def report_response_stats(self, response):
"""
Expand Down Expand Up @@ -283,18 +284,6 @@ def __drug_treatment_graph_expansion(self, describe=False):
if self.response.status != 'OK':
return self.response

# FW: may need these to add the answer graphs; if not, they will be deleted
expander = ARAXExpander()
messenger = ARAXMessenger()
synonymizer = NodeSynonymizer()
decorator = ARAXDecorator()

# expand parameters
mode = 'ARAX'
timeout = 60
kp = 'infores:rtx-kg2'
prune_threshold = 500


# dtd = creativeDTD(data_path, model_path, use_gpu=False)

Expand All @@ -306,184 +295,10 @@ def __drug_treatment_graph_expansion(self, describe=False):
top_drugs = pd.read_csv(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery', 'Infer', 'data',"top_n_drugs.csv"]))
with open(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery', 'Infer', 'data',"result_from_self_predict_top_M_paths.pkl"]),"rb") as fid:
top_paths = pickle.load(fid)


node_names = set([y for paths in top_paths.values() for x in paths for y in x[0].split("->")[::2] if y != ''])
node_info = synonymizer.get_canonical_curies(names=list(node_names))
node_name_to_id = {k:v['preferred_curie'] for k,v in node_info.items() if v is not None}
path_lengths = set([math.floor(len(x[0].split("->"))/2.) for paths in top_paths.values() for x in paths])
max_path_len = max(path_lengths)
disease = list(top_paths.keys())[0][1]
disease_name = list(top_paths.values())[0][0][0].split("->")[-1]
add_qnode_params = {
'key' : "disease",
'name': disease
}
self.response = messenger.add_qnode(self.response, add_qnode_params)
self.response.envelope.message.knowledge_graph.nodes[disease] = Node(name=disease_name, categories=['biolink:DiseaseOrPhenotypicFeature'])
self.response.envelope.message.knowledge_graph.nodes[disease].qnode_keys = ['disease']
node_name_to_id[disease_name] = disease
add_qnode_params = {
'key' : "drug",
'categories': ['biolink:Drug']
}
self.response = messenger.add_qnode(self.response, add_qnode_params)
add_qedge_params = {
'key' : "probably_treats",
'subject' : "drug",
'object' : "disease",
'predicates': ["biolink:probably_treats"]
}
self.response = messenger.add_qedge(self.response, add_qedge_params)
path_keys = [{} for i in range(max_path_len)]
for i in range(max_path_len+1):
if (i+1) in path_lengths:
path_qnodes = ["drug"]
for j in range(i):
new_qnode_key = f"creative_DTD_qnode_{self.qnode_global_iter}"
path_qnodes.append(new_qnode_key)
add_qnode_params = {
'key' : new_qnode_key,
'option_group_id': f"creative_DTD_option_group_{self.option_global_iter}",
"is_set": "true"
}
self.response = messenger.add_qnode(self.response, add_qnode_params)
self.qnode_global_iter += 1
path_qnodes.append("disease")
qnode_pairs = list(zip(path_qnodes,path_qnodes[1:]))
qedge_key_list = []
for qnode_pair in qnode_pairs:
add_qedge_params = {
'key' : f"creative_DTD_qedge_{self.qedge_global_iter}",
'subject' : qnode_pair[0],
'object' : qnode_pair[1],
'option_group_id': f"creative_DTD_option_group_{self.option_global_iter}"
}
qedge_key_list.append(f"creative_DTD_qedge_{self.qedge_global_iter}")
self.qedge_global_iter += 1
self.response = messenger.add_qedge(self.response, add_qedge_params)
path_keys[i]["qnode_pairs"] = qnode_pairs
path_keys[i]["qedge_keys"] = qedge_key_list
self.option_global_iter += 1

# FW: code that will add resulting paths to the query graph and knowledge graph goes here
essence_scores = {}
for (drug, disease), paths in top_paths.items():
path_added = False
# Splits the paths, which are encoded as strings, into lists of node names and edge predicates.
# The x[0] is needed because each element consists of the path string and a score; the score is currently ignored.
split_paths = [x[0].split("->") for x in paths]
for path in split_paths:
drug_name = path[0]
if any([x not in node_name_to_id for x in path[::2]]):
continue
# new_response = ARAXResponse()
# messenger.create_envelope(new_response)
n_elements = len(path)
# Creates edge tuples of the form (node name 1, edge predicate, node name 2)
edge_tuples = [(path[i],path[i+1],path[i+2]) for i in range(0,n_elements-2,2)]
path_idx = len(edge_tuples)-1
added_nodes = set()
for i in range(path_idx+1):
# if path_keys[path_idx]["qnode_pairs"][i][0] not in added_nodes:
# add_qnode_params = {
# 'key' : path_keys[path_idx]["qnode_pairs"][i][0],
# 'name': edge_tuples[i][0]
# }
# new_response = messenger.add_qnode(new_response, add_qnode_params)
# added_nodes.add(path_keys[path_idx]["qnode_pairs"][i][0])
subject_qnode_key = path_keys[path_idx]["qnode_pairs"][i][0]
subject_name = edge_tuples[i][0]
subject_curie = node_name_to_id[subject_name]
subject_category = node_info[subject_name]['preferred_category']
if subject_curie not in self.response.envelope.message.knowledge_graph.nodes:
self.response.envelope.message.knowledge_graph.nodes[subject_curie] = Node(name=subject_name, categories=[subject_category])
self.response.envelope.message.knowledge_graph.nodes[subject_curie].qnode_keys = [subject_qnode_key]
elif subject_qnode_key not in self.response.envelope.message.knowledge_graph.nodes[subject_curie].qnode_keys:
self.response.envelope.message.knowledge_graph.nodes[subject_curie].qnode_keys.append(subject_qnode_key)
# if path_keys[path_idx]["qnode_pairs"][i][1] not in added_nodes:
# add_qnode_params = {
# 'key' : path_keys[path_idx]["qnode_pairs"][i][1],
# 'name': edge_tuples[i][2]
# }
# new_response = messenger.add_qnode(new_response, add_qnode_params)
# added_nodes.add(path_keys[path_idx]["qnode_pairs"][i][1])
object_qnode_key = path_keys[path_idx]["qnode_pairs"][i][1]
object_name = edge_tuples[i][2]
object_curie = node_name_to_id[object_name]
object_category = node_info[object_name]['preferred_category']
if object_curie not in self.response.envelope.message.knowledge_graph.nodes:
self.response.envelope.message.knowledge_graph.nodes[object_curie] = Node(name=object_name, categories=[object_category])
self.response.envelope.message.knowledge_graph.nodes[object_curie].qnode_keys = [object_qnode_key]
elif object_qnode_key not in self.response.envelope.message.knowledge_graph.nodes[object_curie].qnode_keys:
self.response.envelope.message.knowledge_graph.nodes[object_curie].qnode_keys.append(object_qnode_key)
# new_qedge_key = path_keys[path_idx]["qedge_keys"][i]
# add_qedge_params = {
# 'key' : new_qedge_key,
# 'subject' : path_keys[path_idx]["qnode_pairs"][i][0],
# 'object' : path_keys[path_idx]["qnode_pairs"][i][1],
# 'predicates': [edge_tuples[i][1]]
# }
# new_response = messenger.add_qedge(new_response, add_qedge_params)
new_edge = Edge(subject=subject_curie, object=object_curie, predicate=edge_tuples[i][1], attributes=[])
new_edge.attributes.append(EdgeAttribute(attribute_type_id="biolink:aggregator_knowledge_source",
value=kp,
value_type_id="biolink:InformationResource",
attribute_source=kp))
new_edge_key = self.__get_formated_edge_key(edge=new_edge, kp=kp)
self.response.envelope.message.knowledge_graph.edges[new_edge_key] = new_edge
self.response.envelope.message.knowledge_graph.edges[new_edge_key].qedge_keys = [path_keys[path_idx]["qedge_keys"][i]]
# expand_params = {
# 'kp':kp,
# 'prune_threshold':prune_threshold,
# 'edge_key':path_keys[path_idx]["qedge_keys"],
# 'kp_timeout':timeout
# }
# new_response = expander.apply(new_response, expand_params, mode=mode)
# if new_response.status == 'OK':
# for knode_id, knode in new_response.envelope.message.knowledge_graph.nodes.items():
# if 'disease' in knode.qnode_keys:
# normalized_disease = knode_id
# if 'drug' in knode.qnode_keys:
# normalized_drug = knode_id
# normalized_drug_name = knode.name
# if knode_id in self.response.envelope.message.knowledge_graph.nodes:
# new_response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys += self.response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys
# self.response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys = new_response.envelope.message.knowledge_graph.nodes[knode_id].qnode_keys
# self.response.envelope.message.knowledge_graph.nodes.update(new_response.envelope.message.knowledge_graph.nodes)
# self.response.envelope.message.knowledge_graph.edges.update(new_response.envelope.message.knowledge_graph.edges)
# self.response.merge(new_response)
path_added = True
if path_added:
treat_score = top_drugs.loc[top_drugs['drug_id'] == drug]["tp_score"].iloc[0]
essence_scores[drug_name] = treat_score
edge_attribute_list = [
# EdgeAttribute(original_attribute_name="defined_datetime", value=defined_datetime, attribute_type_id="metatype:Datetime"),
EdgeAttribute(original_attribute_name="provided_by", value="infores:arax", attribute_type_id="biolink:aggregator_knowledge_source", attribute_source="infores:arax", value_type_id="biolink:InformationResource"),
EdgeAttribute(original_attribute_name=None, value=True, attribute_type_id="biolink:computed_value", attribute_source="infores:arax-reasoner-ara", value_type_id="metatype:Boolean", value_url=None, description="This edge is a container for a computed value between two nodes that is not directly attachable to other edges."),
EdgeAttribute(attribute_type_id="EDAM:data_0951", original_attribute_name="probability_treats", value=str(treat_score))
]
fixed_edge = Edge(predicate="biolink:probably_treats", subject=node_name_to_id[drug_name], object=node_name_to_id[disease_name],
attributes=edge_attribute_list)
fixed_edge.qedge_keys = ["probably_treats"]
self.response.envelope.message.knowledge_graph.edges[f"creative_DTD_prediction_{self.kedge_global_iter}"] = fixed_edge
self.kedge_global_iter += 1
else:
self.response.warning(f"Something went wrong when adding the subgraph for the drug-disease pair ({drug},{disease}) to the knowledge graph. Skipping this result....")
self.response = decorator.decorate_nodes(self.response)
if self.response.status != 'OK':
return self.response
self.response = decorator.decorate_edges(self.response)
if self.response.status != 'OK':
return self.response
resultifier = ARAXResultify()
resultify_params = {
"ignore_edge_direction": "true"
}
self.response = resultifier.apply(self.response, resultify_params)
for result in self.response.envelope.message.results:
result.score = essence_scores[result.essence]
self.response.envelope.message.results.sort(key=lambda x: x.score, reverse=True)
iu = InferUtilities()
self.response, self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter = iu.genrete_treat_subgraphs(self.response, top_drugs, top_paths, self.kedge_global_iter, self.qedge_global_iter, self.qnode_global_iter, self.option_global_iter)

return self.response


Expand Down
Loading

0 comments on commit e4083bf

Please sign in to comment.