Skip to content

Commit 480573e

Browse files
committed
Merge remote-tracking branch 'origin/edeutsch-kp-cache'
2 parents b9e094a + 333174f commit 480573e

File tree

7 files changed

+1029
-36
lines changed

7 files changed

+1029
-36
lines changed

code/ARAX/ARAXQuery/ARAX_background_tasker.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from importlib.metadata import version
1212

1313
from ARAX_query_tracker import ARAXQueryTracker
14+
from Expand.trapi_query_cacher import KPQueryCacher
1415

1516
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
1617
from RTXConfiguration import RTXConfiguration
@@ -128,6 +129,9 @@ def run_tasks(self):
128129
eprint(result.stdout.decode('utf-8'))
129130
eprint("INFO: End listing databases area contents")
130131

132+
#### Set up the KP Cacher to be used for periodic refreshing
133+
kp_cacher = KPQueryCacher()
134+
131135
# Loop forever doing various things
132136
my_pid = os.getpid()
133137
while True:
@@ -198,14 +202,23 @@ def run_tasks(self):
198202
n_clients += 1
199203
n_ongoing_queries += n_queries
200204

205+
#### Refresh the KP cache
206+
start_time = time.time()
207+
kp_cacher.refresh_cache()
208+
elapsed_time = time.time() - start_time
209+
if elapsed_time < FREQ_CHECK_ONGOING_SEC - 1:
210+
time_to_sleep = FREQ_CHECK_ONGOING_SEC - round(elapsed_time)
211+
else:
212+
time_to_sleep = 2
213+
201214
load_tuple = psutil.getloadavg()
202215

203216
timestamp = str(datetime.datetime.now().isoformat())
204217
eprint(f"{timestamp}: INFO: ARAXBackgroundTasker "
205218
f"(PID {my_pid}) status: waiting. Current "
206219
f"load is {load_tuple}, n_clients={n_clients}, "
207220
f"n_ongoing_queries={n_ongoing_queries}")
208-
time.sleep(FREQ_CHECK_ONGOING_SEC)
221+
time.sleep(time_to_sleep)
209222

210223

211224
def main():

code/ARAX/ARAXQuery/ARAX_connect.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,20 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)
1010
import os
1111
from collections import Counter
1212
import copy
13+
import time
1314

1415
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
1516
from Path_Finder.converter.EdgeExtractorFromPloverDB import EdgeExtractorFromPloverDB
1617
from Path_Finder.converter.ResultPerPathConverter import ResultPerPathConverter
1718
from Path_Finder.converter.Names import Names
1819
from Path_Finder.BidirectionalPathFinder import BidirectionalPathFinder
1920

21+
from Expand.trapi_query_cacher import KPQueryCacher
22+
from ARAX_messenger import ARAXMessenger
23+
2024
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../UI/OpenAPI/python-flask-server/")
2125
from openapi_server.models.knowledge_graph import KnowledgeGraph
26+
from openapi_server.models.pathfinder_analysis import PathfinderAnalysis
2227

2328
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../NodeSynonymizer/")
2429
from node_synonymizer import NodeSynonymizer
@@ -185,15 +190,64 @@ def apply(self, input_response, input_parameters):
185190
self.response.data['parameters'] = parameters
186191
self.parameters = parameters
187192

188-
getattr(self, '_' + self.__class__.__name__ + '__' + parameters[
189-
'action'])() # thank you https://stackoverflow.com/questions/11649848/call-methods-by-string
190-
191-
self.response.debug(f"Applying Connect to Message with parameters {parameters}")
193+
#### Check the cache to see if we have this query cached already
194+
start = time.time()
195+
cacher = KPQueryCacher()
196+
kp_curie = "PathFinder"
197+
kp_url = "PathFinder"
198+
response_envelope_as_dict = self.response.envelope.to_dict()
199+
cleaned_parameters = self._clean_parameters(parameters)
200+
pathfinder_input_data = { 'query_graph': response_envelope_as_dict['message']['query_graph'], 'parameters': cleaned_parameters }
201+
self.response.info(f"Looking for a previously cached result from {kp_curie}")
202+
response_data, response_code, elapsed_time, error = cacher.get_cached_result(kp_curie, pathfinder_input_data)
203+
if response_code != -2:
204+
n_results = cacher._get_n_results(response_data)
205+
self.response.info(f"Found a cached result with response_code={response_code}, n_results={n_results} from the cache in {elapsed_time:.3f} seconds")
206+
self.response.envelope.message = ARAXMessenger().from_dict(response_data['message'])
207+
208+
# Hack to explicitly convert the analyses to PathfinderAnalysis objects because this doesn't work automatically. It should. Maybe move this into Messenger? FIXME
209+
i_analysis = 0
210+
for analysis_dict in response_data['message']['results'][0]['analyses']:
211+
analysis_obj = PathfinderAnalysis.from_dict(analysis_dict)
212+
self.response.envelope.message.results[0].analyses[i_analysis] = analysis_obj
213+
i_analysis += 1
214+
215+
else:
216+
self.response.debug(f"Applying Connect to Message with parameters {parameters}")
217+
218+
#### This will effectively call __connect_nodes() unless the user injects something else
219+
getattr(self, '_' + self.__class__.__name__ + '__' + parameters[
220+
'action'])() # thank you https://stackoverflow.com/questions/11649848/call-methods-by-string
221+
222+
#### Store the result into the cache for next time
223+
elapsed_time = time.time() - start
224+
self.response.info(f"Got result from ARAX PathFinder Connect after {elapsed_time}. Converting to_dict()")
225+
response_object = self.response.envelope.to_dict()
226+
self.response.info(f"Storing resulting dict in the cache")
227+
cacher.store_response(
228+
kp_curie=kp_curie,
229+
query_url=kp_url,
230+
query_object=pathfinder_input_data,
231+
response_object=response_object,
232+
http_code=200,
233+
elapsed_time=elapsed_time,
234+
status="OK"
235+
)
236+
self.response.info(f"Stored result in the cache.")
192237

193238
if self.report_stats: # helper to report information in debug if class self.report_stats = True
194239
self.response = self.report_response_stats(self.response)
195240
return self.response
196241

242+
243+
#### During processing, sometimes these parameters change from a string (of an integer) to an integer, so just force them all to strings for the purpose of cache comparison
244+
def _clean_parameters(self, parameters):
245+
cleaned_parameters = parameters.copy()
246+
cleaned_parameters['max_path_length'] = str(cleaned_parameters['max_path_length'])
247+
cleaned_parameters['max_pathfinder_paths'] = str(cleaned_parameters['max_pathfinder_paths'])
248+
return cleaned_parameters
249+
250+
197251
def get_pinned_nodes(self):
198252
pinned_nodes = []
199253
for key, node in self.message.query_graph.nodes.items():

code/ARAX/ARAXQuery/ARAX_expander.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from openapi_server.models.attribute import Attribute
3131
from openapi_server.models.retrieval_source import RetrievalSource
3232
from Expand.trapi_querier import TRAPIQuerier
33+
from Expand.trapi_query_cacher import KPQueryCacher
34+
from ARAX_messenger import ARAXMessenger
3335

3436
UNBOUND_NODES_KEY = "__UNBOUND__"
3537

@@ -329,6 +331,8 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
329331
# Get any inferred results from ARAX Infer
330332
if inferred_qedge_keys:
331333
response, overarching_kg = self.get_inferred_answers(inferred_qedge_keys, query_graph, response)
334+
#### Update the local message with a potentially new message created in previous method call
335+
message = response.envelope.message
332336
if log.status != 'OK':
333337
return response
334338
# Now mark qedges as 'lookup' if this is an inferred query
@@ -549,6 +553,18 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
549553
message.encountered_kryptonite_edges_info, response)
550554
# Remove any paths that are now dead-ends
551555
if inferred_qedge_keys and len(inferred_qedge_keys) == 1:
556+
557+
#### Write some state information to files for debugging
558+
debug_filepath = os.path.dirname(os.path.abspath(__file__))
559+
if hasattr(response, 'dtd_from_cache') and response.dtd_from_cache is True:
560+
debug_filepath += "/zz_cache_"
561+
else:
562+
debug_filepath += "/zz_fresh_"
563+
with open(debug_filepath + "query_graph.json", 'w') as outfile:
564+
print(f"*******line 564: message.query_graph={message.query_graph}", file=outfile)
565+
with open(debug_filepath + "overarching_kg.json", 'w') as outfile:
566+
print(f"*******line 566: overarching_kg={overarching_kg}", file=outfile)
567+
552568
overarching_kg = self._remove_dead_end_paths(message.query_graph, overarching_kg, response)
553569
else:
554570
overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response)
@@ -668,13 +684,61 @@ def get_inferred_answers(inferred_qedge_keys: list[str],
668684
infer_input_parameters = {"action": "drug_treatment_graph_expansion",
669685
'disease_curie': object_curie, 'qedge_id': inferred_qedge_key,
670686
'drug_curie': subject_curie}
671-
inferer = ARAXInfer()
672-
infer_response = inferer.apply(response, infer_input_parameters)
687+
688+
#### Check the cache to see if we have this query cached already
689+
cacher = KPQueryCacher()
690+
enable_caching = False
691+
kp_curie = "xDTD"
692+
kp_url = "xDTD"
693+
if enable_caching:
694+
response.info(f"Looking for a previously cached result from {kp_curie}")
695+
response_data, response_code, elapsed_time, error = cacher.get_cached_result(kp_curie, infer_input_parameters)
696+
else:
697+
response.info(f"KP results caching for xDTD is currently disabled, pending further debugging")
698+
if enable_caching and response_code != -2:
699+
n_results = cacher._get_n_results(response_data)
700+
response.info(f"Found a cached result with response_code={response_code}, n_results={n_results} from the cache in {elapsed_time:.3f} seconds")
701+
#### Transform the dict message into objects
702+
response.envelope.message = ARAXMessenger().from_dict(response_data['message'])
703+
response.envelope.message.encountered_kryptonite_edges_info = response_data['message']['encountered_kryptonite_edges_info']
704+
for node_key, node in response_data['message']['knowledge_graph']['nodes'].items():
705+
response.info(f"Copying qnode_keys for node {node_key}")
706+
response.envelope.message.knowledge_graph.nodes[node_key].qnode_keys = node['qnode_keys']
707+
for edge_key, edge in response_data['message']['knowledge_graph']['edges'].items():
708+
response.info(f"Copying qedge_keys for edge {edge_key}")
709+
response.envelope.message.knowledge_graph.edges[edge_key].qedge_keys = edge['qedge_keys']
710+
response.dtd_from_cache = True
711+
712+
#### Else run the inferer to get the result and then cache it
713+
else:
714+
inferer = ARAXInfer()
715+
response.info(f"Launching ARAX inferer")
716+
infer_response = inferer.apply(response, infer_input_parameters)
717+
elapsed_time = time.time() - start
718+
response.info(f"Got result from ARAX inferer after {elapsed_time}. Converting to_dict()")
719+
response_object = response.envelope.to_dict()
720+
response_object['message']['encountered_kryptonite_edges_info'] = response.envelope.message.encountered_kryptonite_edges_info
721+
for node_key, node in response.envelope.message.knowledge_graph.nodes.items():
722+
response_object['message']['knowledge_graph']['nodes'][node_key]['qnode_keys'] = node.qnode_keys
723+
for edge_key, edge in response.envelope.message.knowledge_graph.edges.items():
724+
response_object['message']['knowledge_graph']['edges'][edge_key]['qedge_keys'] = edge.qedge_keys
725+
response.info(f"Storing result in the cache")
726+
cacher.store_response(
727+
kp_curie=kp_curie,
728+
query_url=kp_url,
729+
query_object=infer_input_parameters,
730+
response_object=response_object,
731+
http_code=200,
732+
elapsed_time=elapsed_time,
733+
status="OK"
734+
)
735+
response.info(f"Stored result in the cache.")
736+
673737
# return infer_response
674-
response = infer_response
738+
#response = infer_response # these are already always the same object?
675739
overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(response.envelope.message.knowledge_graph)
676740

677-
wait_time = round(time.time() - start)
741+
wait_time = round(time.time() - start, 2)
678742
if response.status == "OK":
679743
done_message = f"Returned {len(overarching_kg.edges_by_qg_id.get(inferred_qedge_key, dict()))} " \
680744
f"edges in {wait_time} seconds"

0 commit comments

Comments
 (0)