30 | 30 | from openapi_server.models.attribute import Attribute |
31 | 31 | from openapi_server.models.retrieval_source import RetrievalSource |
32 | 32 | from Expand.trapi_querier import TRAPIQuerier |
| 33 | +from Expand.trapi_query_cacher import KPQueryCacher |
| 34 | +from ARAX_messenger import ARAXMessenger |
33 | 35 |
34 | 36 | UNBOUND_NODES_KEY = "__UNBOUND__" |
35 | 37 |
@@ -329,6 +331,8 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): |
329 | 331 | # Get any inferred results from ARAX Infer |
330 | 332 | if inferred_qedge_keys: |
331 | 333 | response, overarching_kg = self.get_inferred_answers(inferred_qedge_keys, query_graph, response) |
| 334 | + #### Update the local message, since the previous method call may have created a new one
| 335 | + message = response.envelope.message |
332 | 336 | if log.status != 'OK': |
333 | 337 | return response |
334 | 338 | # Now mark qedges as 'lookup' if this is an inferred query |
@@ -549,6 +553,18 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): |
549 | 553 | message.encountered_kryptonite_edges_info, response) |
550 | 554 | # Remove any paths that are now dead-ends |
551 | 555 | if inferred_qedge_keys and len(inferred_qedge_keys) == 1: |
| 556 | + |
| 557 | + #### Write some state information to files for debugging |
| 558 | + debug_filepath = os.path.dirname(os.path.abspath(__file__)) |
| 559 | + if getattr(response, 'dtd_from_cache', False):
| 560 | + debug_filepath += "/zz_cache_" |
| 561 | + else: |
| 562 | + debug_filepath += "/zz_fresh_" |
| 563 | + with open(debug_filepath + "query_graph.json", 'w') as outfile:  # repr dump, not actually JSON
| 564 | + print(f"message.query_graph={message.query_graph}", file=outfile)
| 565 | + with open(debug_filepath + "overarching_kg.json", 'w') as outfile:  # repr dump, not actually JSON
| 566 | + print(f"overarching_kg={overarching_kg}", file=outfile)
| 567 | + |
552 | 568 | overarching_kg = self._remove_dead_end_paths(message.query_graph, overarching_kg, response) |
553 | 569 | else: |
554 | 570 | overarching_kg = self._remove_dead_end_paths(query_graph, overarching_kg, response) |
@@ -668,13 +684,61 @@ def get_inferred_answers(inferred_qedge_keys: list[str], |
668 | 684 | infer_input_parameters = {"action": "drug_treatment_graph_expansion", |
669 | 685 | 'disease_curie': object_curie, 'qedge_id': inferred_qedge_key, |
670 | 686 | 'drug_curie': subject_curie} |
671 | | - inferer = ARAXInfer() |
672 | | - infer_response = inferer.apply(response, infer_input_parameters) |
| 687 | + |
| 688 | + #### Check whether a result for this query is already in the cache
| 689 | + cacher = KPQueryCacher() |
| 690 | + enable_caching = False |
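| | + #### xDTD inference runs locally rather than against a remote KP, so "xDTD" stands in for both the curie and the url in the cache key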
| 691 | + kp_curie = "xDTD" |
| 692 | + kp_url = "xDTD" |
| 693 | + if enable_caching: |
| 694 | + response.info(f"Looking for a previously cached result from {kp_curie}") |
| 695 | + response_data, response_code, elapsed_time, error = cacher.get_cached_result(kp_curie, infer_input_parameters) |
| 696 | + else: |
| 697 | + response.info("KP results caching for xDTD is currently disabled, pending further debugging")
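| | + #### A response_code of -2 from get_cached_result() indicates a cache miss; any other value is treated as a hit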
| 698 | + if enable_caching and response_code != -2: |
| 699 | + n_results = cacher._get_n_results(response_data) |
| 700 | + response.info(f"Found a cached result (response_code={response_code}, n_results={n_results}) in {elapsed_time:.3f} seconds")
| 701 | + #### Transform the dict message into objects |
| 702 | + response.envelope.message = ARAXMessenger().from_dict(response_data['message']) |
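| | + #### from_dict() does not appear to preserve these extra fields, so restore them manually from the raw dict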
| 703 | + response.envelope.message.encountered_kryptonite_edges_info = response_data['message']['encountered_kryptonite_edges_info'] |
| 704 | + for node_key, node in response_data['message']['knowledge_graph']['nodes'].items(): |
| 705 | + response.info(f"Copying qnode_keys for node {node_key}") |
| 706 | + response.envelope.message.knowledge_graph.nodes[node_key].qnode_keys = node['qnode_keys'] |
| 707 | + for edge_key, edge in response_data['message']['knowledge_graph']['edges'].items(): |
| 708 | + response.info(f"Copying qedge_keys for edge {edge_key}") |
| 709 | + response.envelope.message.knowledge_graph.edges[edge_key].qedge_keys = edge['qedge_keys'] |
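| | + #### Flag that this result came from the cache; checked above when choosing the zz_cache_/zz_fresh_ debug filenames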
| 710 | + response.dtd_from_cache = True |
| 711 | + |
| 712 | + #### Else run the inferer to get the result and then cache it |
| 713 | + else: |
| 714 | + inferer = ARAXInfer() |
| 715 | + response.info("Launching the ARAX inferer")
| 716 | + infer_response = inferer.apply(response, infer_input_parameters) |
| 717 | + elapsed_time = time.time() - start |
| 718 | + response.info(f"Got result from ARAX inferer after {elapsed_time:.2f} seconds. Converting with to_dict()")
| 719 | + response_object = response.envelope.to_dict() |
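| | + #### Mirror of the restore logic above: to_dict() omits these fields, so re-add them before caching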
| 720 | + response_object['message']['encountered_kryptonite_edges_info'] = response.envelope.message.encountered_kryptonite_edges_info |
| 721 | + for node_key, node in response.envelope.message.knowledge_graph.nodes.items(): |
| 722 | + response_object['message']['knowledge_graph']['nodes'][node_key]['qnode_keys'] = node.qnode_keys |
| 723 | + for edge_key, edge in response.envelope.message.knowledge_graph.edges.items(): |
| 724 | + response_object['message']['knowledge_graph']['edges'][edge_key]['qedge_keys'] = edge.qedge_keys |
| 725 | + response.info(f"Storing result in the cache") |
| 726 | + cacher.store_response( |
| 727 | + kp_curie=kp_curie, |
| 728 | + query_url=kp_url, |
| 729 | + query_object=infer_input_parameters, |
| 730 | + response_object=response_object, |
| 731 | + http_code=200, |
| 732 | + elapsed_time=elapsed_time, |
| 733 | + status="OK" |
| 734 | + ) |
| 735 | + response.info("Stored result in the cache")
| 736 | + |
673 | 737 | # return infer_response |
674 | | - response = infer_response |
| 738 | + # response = infer_response  # unnecessary? response and infer_response appear to always be the same object
675 | 739 | overarching_kg = eu.convert_standard_kg_to_qg_organized_kg(response.envelope.message.knowledge_graph) |
676 | 740 |
677 | | - wait_time = round(time.time() - start) |
| 741 | + wait_time = round(time.time() - start, 2) |
678 | 742 | if response.status == "OK": |
679 | 743 | done_message = f"Returned {len(overarching_kg.edges_by_qg_id.get(inferred_qedge_key, dict()))} " \ |
680 | 744 | f"edges in {wait_time} seconds" |