diff --git a/.gitignore b/.gitignore index 7427d5fd9..eae16ccee 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,4 @@ code/ARAX/ARAXQuery/Infer/data/xCRG_data/*.npz code/UI/OpenAPI/python-flask-server/openapi_server/openapi/openapi.json code/UI/OpenAPI/specifications/export/KG2/*/openapi.json +**/cache_refresh.pid diff --git a/code/ARAX/ARAXQuery/ARAX_expander.py b/code/ARAX/ARAXQuery/ARAX_expander.py index f587e0930..4804f2ded 100644 --- a/code/ARAX/ARAXQuery/ARAX_expander.py +++ b/code/ARAX/ARAXQuery/ARAX_expander.py @@ -7,7 +7,7 @@ import time import traceback from collections import defaultdict -from typing import List, Dict, Tuple, Union, Set, Optional +from typing import Union, Optional, Any sys.path.append(os.path.dirname(os.path.abspath(__file__))) # ARAXQuery directory from ARAX_response import ARAXResponse @@ -31,6 +31,7 @@ from openapi_server.models.retrieval_source import RetrievalSource from Expand.trapi_querier import TRAPIQuerier +UNBOUND_NODES_KEY = "__UNBOUND__" def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) @@ -147,7 +148,11 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): message.knowledge_graph = KnowledgeGraph(nodes=dict(), edges=dict()) log = response # Fetch the list of all registered kps with compatible versions - kp_selector = KPSelector(log=log) + try: + kp_selector = KPSelector(log=log) + except ValueError as e: + response.error(str(e)) + return response # Save the original QG, if it hasn't already been saved in ARAXQuery (happens for DSL queries..) if not hasattr(response, "original_query_graph"): @@ -597,9 +602,9 @@ def apply(self, response, input_parameters, mode: str = "ARAX"): return response @staticmethod - def get_inferred_answers(inferred_qedge_keys: List[str], + def get_inferred_answers(inferred_qedge_keys: list[str], query_graph: QueryGraph, - response: ARAXResponse) -> Tuple[ARAXResponse, QGOrganizedKnowledgeGraph]: + response: ARAXResponse) -> tuple[ARAXResponse, QGOrganizedKnowledgeGraph]: # Send ARAXInfer any one-hop, "inferred", "treats" queries (temporary way of making creative mode work) overarching_kg = QGOrganizedKnowledgeGraph() if inferred_qedge_keys: @@ -745,7 +750,7 @@ async def _expand_edge_async(self, edge_qg: QueryGraph, kp_selector: KPSelector, log: ARAXResponse, multiple_kps: bool = False, - be_creative_treats: bool = False) -> Tuple[QGOrganizedKnowledgeGraph, ARAXResponse]: + be_creative_treats: bool = False) -> tuple[QGOrganizedKnowledgeGraph, ARAXResponse]: # This function answers a single-edge (one-hop) query using the specified knowledge provider qedge_key = next(qedge_key for qedge_key in edge_qg.edges) log.info(f"Expanding qedge {qedge_key} using {kp_to_use}") @@ -815,7 +820,7 @@ async def _expand_edge_async(self, edge_qg: QueryGraph, @staticmethod def _expand_node(qnode_key: str, - kps_to_use: List[str], + kps_to_use: list[str], query_graph: QueryGraph, user_specified_kp: bool, kp_timeout: Optional[int], @@ -901,8 +906,8 @@ def _deduplicate_nodes(answer_kg: QGOrganizedKnowledgeGraph, kp_name: str, log: edges={qedge_key: dict() for qedge_key in answer_kg.edges_by_qg_id}) curie_mappings = dict() - # First deduplicate the nodes - for qnode_key, nodes in answer_kg.nodes_by_qg_id.items(): + # First deduplicate the bound nodes + for qnode_key, nodes in {**answer_kg.nodes_by_qg_id, UNBOUND_NODES_KEY: answer_kg.unbound_nodes}.items(): # Load preferred curie info from NodeSynonymizer log.debug(f"{kp_name}: Getting preferred curies for {qnode_key} nodes returned in this step") canonicalized_nodes = eu.get_canonical_curies_dict(list(nodes), log) if nodes else dict() @@ -927,11 +932,17 @@ def _deduplicate_nodes(answer_kg: QGOrganizedKnowledgeGraph, kp_name: str, log: curie_mappings[node_key] = preferred_curie # Add this node into our deduplicated KG as necessary - if preferred_curie not in deduplicated_kg.nodes_by_qg_id[qnode_key]: - node_key = preferred_curie - node.name = preferred_name - node.categories = preferred_categories - deduplicated_kg.add_node(node_key, node, qnode_key) + if qnode_key != UNBOUND_NODES_KEY: + if preferred_curie not in deduplicated_kg.nodes_by_qg_id[qnode_key]: + node_key = preferred_curie + node.name = preferred_name + node.categories = preferred_categories + deduplicated_kg.add_node(node_key, node, qnode_key) + else: # this is an unbound node + if preferred_curie not in deduplicated_kg.unbound_nodes: + node.name = preferred_name + node.categories = preferred_categories + deduplicated_kg.unbound_nodes[preferred_curie] = node # Then update the edges to reflect changes made to the nodes dropped_edge_count = dict() @@ -957,7 +968,7 @@ def _deduplicate_nodes(answer_kg: QGOrganizedKnowledgeGraph, kp_name: str, log: return deduplicated_kg, dropped_edge_count @staticmethod - def _extract_query_subgraph(qedge_keys_to_expand: List[str], query_graph: QueryGraph, log: ARAXResponse) -> QueryGraph: + def _extract_query_subgraph(qedge_keys_to_expand: list[str], query_graph: QueryGraph, log: ARAXResponse) -> QueryGraph: # This function extracts a sub-query graph containing the provided qedge IDs from a larger query graph sub_query_graph = QueryGraph(nodes=dict(), edges=dict()) @@ -1059,7 +1070,7 @@ def _merge_answer_into_message_kg(answer_kg: QGOrganizedKnowledgeGraph, overarch @staticmethod def _store_kryptonite_edge_info(kryptonite_kg: QGOrganizedKnowledgeGraph, kryptonite_qedge_key: str, qg: QueryGraph, - encountered_kryptonite_edges_info: Dict[str, Dict[str, Set[str]]], log: ARAXResponse): + encountered_kryptonite_edges_info: dict[str, dict[str, set[str]]], log: ARAXResponse): """ This function adds the IDs of nodes found by expansion of the given kryptonite ("not") edge to the global encountered_kryptonite_edges_info dictionary, which is organized by QG IDs. This allows Expand to "remember" @@ -1087,7 +1098,7 @@ def _store_kryptonite_edge_info(kryptonite_kg: QGOrganizedKnowledgeGraph, krypto @staticmethod def _apply_any_kryptonite_edges(organized_kg: QGOrganizedKnowledgeGraph, full_query_graph: QueryGraph, - encountered_kryptonite_edges_info: Dict[str, Dict[str, Set[str]]], log): + encountered_kryptonite_edges_info: dict[str, dict[str, set[str]]], log): """ This function breaks any paths in the KG for which a "not" (exclude=True) condition has been met; the remains of the broken paths not used in other paths in the KG are cleaned up during later pruning of dead ends. The @@ -1148,8 +1159,11 @@ def _prune_kg(qnode_key_to_prune: str, prune_threshold: int, kg: QGOrganizedKnow num_edges_in_kg = sum([len(edges) for edges in kg.edges_by_qg_id.values()]) overlay_fet = True if num_edges_in_kg < 100000 else False # Use fisher exact test and the ranker to prune down answers for this qnode - intermediate_results_response = eu.create_results(qg_for_resultify, kg_copy, log, - rank_results=True, overlay_fet=overlay_fet, + intermediate_results_response = eu.create_results(qg_for_resultify, + kg_copy, + log, + rank_results=True, + overlay_fet=overlay_fet, qnode_key_to_prune=qnode_key_to_prune) log.debug(f"A total of {len(intermediate_results_response.envelope.message.results)} " f"intermediate results were created/ranked") @@ -1157,7 +1171,7 @@ def _prune_kg(qnode_key_to_prune: str, prune_threshold: int, kg: QGOrganizedKnow # Filter down so we only keep the top X nodes results = intermediate_results_response.envelope.message.results results.sort(key=lambda x: x.analyses[0].score, reverse=True) - kept_nodes = set() + kept_nodes: set[str] = set() scores = [] counter = 0 while len(kept_nodes) < prune_threshold and counter < len(results): @@ -1203,7 +1217,7 @@ def _remove_dead_end_paths(expands_qg: QueryGraph, kg: QGOrganizedKnowledgeGraph @staticmethod def _add_node_connection_to_map(qnode_key_a: str, qnode_key_b: str, edge: Edge, - node_connections_map: Dict[str, Dict[str, Dict[str, Set[str]]]]): + node_connections_map: dict[str, dict[str, dict[str, set[str]]]]): # This is a helper function that's used when building a map of which nodes are connected to which # Example node_connections_map: {'n01': {'UMLS:1222': {'n00': {'DOID:122'}, 'n02': {'UniProtKB:22'}}}, ...} # Initiate entries for this edge's nodes as needed @@ -1220,14 +1234,14 @@ def _add_node_connection_to_map(qnode_key_a: str, qnode_key_b: str, edge: Edge, node_connections_map[qnode_key_a][edge.subject][qnode_key_b].add(edge.object) node_connections_map[qnode_key_b][edge.object][qnode_key_a].add(edge.subject) - def _get_order_to_expand_qedges_in(self, query_graph: QueryGraph, log: ARAXResponse) -> List[str]: + def _get_order_to_expand_qedges_in(self, query_graph: QueryGraph, log: ARAXResponse) -> list[str]: """ This function determines what order to expand the edges in a query graph in; it aims to start with a required, non-kryptonite qedge that has a qnode with a curie specified. It then looks for a qedge connected to that starting qedge, and so on. """ qedge_keys_remaining = [qedge_key for qedge_key in query_graph.edges] - ordered_qedge_keys = [] + ordered_qedge_keys: list[str] = [] while qedge_keys_remaining: if not ordered_qedge_keys: # Try to start with a required, non-kryptonite qedge that has a qnode with a curie specified @@ -1258,7 +1272,7 @@ def _get_order_to_expand_qedges_in(self, query_graph: QueryGraph, log: ARAXRespo return ordered_qedge_keys @staticmethod - def _find_qedge_connected_to_subgraph(subgraph_qedge_keys: List[str], qedge_keys_to_choose_from: List[str], + def _find_qedge_connected_to_subgraph(subgraph_qedge_keys: list[str], qedge_keys_to_choose_from: list[str], qg: QueryGraph) -> Optional[str]: qnode_keys_in_subgraph = {qnode_key for qedge_key in subgraph_qedge_keys for qnode_key in {qg.edges[qedge_key].subject, qg.edges[qedge_key].object}} @@ -1322,9 +1336,9 @@ def _remove_self_edges(kg: QGOrganizedKnowledgeGraph, kp_name: str, log: ARAXRes log.debug(f"{kp_name}: After removing self-edges, answer KG counts are: {eu.get_printable_counts_by_qg_id(kg)}") return kg - def _set_and_validate_parameters(self, input_parameters: Dict[str, any], kp_selector: KPSelector, - log: ARAXResponse) -> Dict[str, any]: - parameters = dict() + def _set_and_validate_parameters(self, input_parameters: dict[str, Any], kp_selector: KPSelector, + log: ARAXResponse) -> dict[str, Any]: + parameters: dict[str, Any] = dict() parameter_info_dict = self.get_parameter_info_dict() # First set parameters to their defaults @@ -1359,7 +1373,7 @@ def _set_and_validate_parameters(self, input_parameters: Dict[str, any], kp_sele return parameters @staticmethod - def is_supported_constraint(constraint: AttributeConstraint, supported_constraints_map: Dict[str, Dict[str, Set[str]]]) -> bool: + def is_supported_constraint(constraint: AttributeConstraint, supported_constraints_map: dict[str, dict[str, set[str]]]) -> bool: if constraint.id not in supported_constraints_map: return False if constraint.operator not in supported_constraints_map[constraint.id]: @@ -1371,7 +1385,7 @@ def is_supported_constraint(constraint: AttributeConstraint, supported_constrain return True @staticmethod - def _load_fda_approved_drug_ids() -> Set[str]: + def _load_fda_approved_drug_ids() -> set[str]: # Determine the local path to the FDA-approved drugs pickle path_list = os.path.realpath(__file__).split(os.path.sep) rtx_index = path_list.index("RTX") @@ -1467,12 +1481,12 @@ def _get_orphan_qnode_keys(query_graph: QueryGraph): return list(all_qnode_keys.difference(qnode_keys_used_by_qedges)) @staticmethod - def _get_qedges_with_curie_qnode(query_graph: QueryGraph) -> List[str]: + def _get_qedges_with_curie_qnode(query_graph: QueryGraph) -> list[str]: return [qedge_key for qedge_key, qedge in query_graph.edges.items() if query_graph.nodes[qedge.subject].ids or query_graph.nodes[qedge.object].ids] @staticmethod - def _find_connected_qedge(qedge_choices: List[QEdge], qedge: QEdge) -> QEdge: + def _find_connected_qedge(qedge_choices: list[QEdge], qedge: QEdge) -> QEdge: qedge_qnode_keys = {qedge.subject, qedge.object} connected_qedges = [] for other_qedge in qedge_choices: diff --git a/code/ARAX/ARAXQuery/ARAX_response.py b/code/ARAX/ARAXQuery/ARAX_response.py index e0390bb9f..62d2cb867 100644 --- a/code/ARAX/ARAXQuery/ARAX_response.py +++ b/code/ARAX/ARAXQuery/ARAX_response.py @@ -36,8 +36,10 @@ def __init__(self, status='OK', logging_level=WARNING, error_code='OK', message= self.n_warnings = 0 self.data = {} self.envelope = None - self.query_plan = { 'qedge_keys': {}, 'counter': 0 } + self.wait_time = None # this attribute is set by trapi_querier.py + self.http_error = None # this attribute is set by trapi_querier.py + self.timed_out = None # this attribute is set by trapi_querier.py #### Add a debugging message diff --git a/code/ARAX/ARAXQuery/Expand/expand_utilities.py b/code/ARAX/ARAXQuery/Expand/expand_utilities.py index 85e46b923..4d1705e06 100644 --- a/code/ARAX/ARAXQuery/Expand/expand_utilities.py +++ b/code/ARAX/ARAXQuery/Expand/expand_utilities.py @@ -5,7 +5,7 @@ import os import traceback import yaml -from typing import List, Dict, Union, Set, Tuple, Optional +from typing import Union, Optional sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/../../../UI/OpenAPI/python-flask-server/") from openapi_server.models.knowledge_graph import KnowledgeGraph @@ -36,12 +36,16 @@ class QGOrganizedKnowledgeGraph: - def __init__(self, nodes: Dict[str, Dict[str, Node]] = None, edges: Dict[str, Dict[str, Edge]] = None): + def __init__(self, + nodes: Optional[dict[str, dict[str, Node]]] = None, + edges: Optional[dict[str, dict[str, Edge]]] = None, + unbound_nodes: Optional[dict[str, Node]] = None): self.nodes_by_qg_id = nodes if nodes else dict() self.edges_by_qg_id = edges if edges else dict() + self.unbound_nodes = unbound_nodes if unbound_nodes else dict() def __str__(self): - return f"nodes_by_qg_id:\n{self.nodes_by_qg_id}\nedges_by_qg_id:\n{self.edges_by_qg_id}" + return f"nodes_by_qg_id:\n{self.nodes_by_qg_id}\nedges_by_qg_id:\n{self.edges_by_qg_id}\nunbound_nodes:\n{self.unbound_nodes}" def add_node(self, node_key: str, node: Node, qnode_key: str): if qnode_key not in self.nodes_by_qg_id: @@ -72,7 +76,7 @@ def add_edge(self, edge_key: str, edge: Edge, qedge_key: str): self.edges_by_qg_id[qedge_key] = dict() self.edges_by_qg_id[qedge_key][edge_key] = edge - def remove_nodes(self, node_keys_to_delete: Set[str], target_qnode_key: str, qg: QueryGraph): + def remove_nodes(self, node_keys_to_delete: set[str], target_qnode_key: str, qg: QueryGraph): # First delete the specified nodes for node_key in node_keys_to_delete: del self.nodes_by_qg_id[target_qnode_key][node_key] @@ -96,16 +100,16 @@ def remove_nodes(self, node_keys_to_delete: Set[str], target_qnode_key: str, qg: for orphan_node_key in orphan_node_keys: del self.nodes_by_qg_id[non_orphan_qnode_key][orphan_node_key] - def get_all_node_keys_used_by_edges(self) -> Set[str]: + def get_all_node_keys_used_by_edges(self) -> set[str]: return {node_key for edges in self.edges_by_qg_id.values() for edge in edges.values() for node_key in [edge.subject, edge.object]} - def get_node_keys_used_by_edges_fulfilling_qedge(self, qedge_key: str) -> Set[str]: + def get_node_keys_used_by_edges_fulfilling_qedge(self, qedge_key: str) -> set[str]: relevant_edges = self.edges_by_qg_id.get(qedge_key, dict()) return {node_key for edge in relevant_edges.values() for node_key in [edge.subject, edge.object]} - def get_all_node_keys(self) -> Set[str]: + def get_all_node_keys(self) -> set[str]: return {node_key for nodes in self.nodes_by_qg_id.values() for node_key in nodes} def is_empty(self) -> bool: @@ -169,7 +173,7 @@ def convert_string_to_snake_case(input_string: str) -> str: return input_string.lower() -def convert_to_list(item: Union[str, set, list, None]) -> List[str]: +def convert_to_list(item: Union[str, set, list, None]) -> list[str]: if isinstance(item, str): return [item] elif isinstance(item, set): @@ -180,16 +184,16 @@ def convert_to_list(item: Union[str, set, list, None]) -> List[str]: return [] -def convert_to_set(item: Union[str, set, list, None]) -> Set[str]: +def convert_to_set(item: Union[str, set, list, None]) -> set[str]: item_list = convert_to_list(item) return set(item_list) -def get_node_keys_used_by_edges(edges_dict: Dict[str, Edge]) -> Set[str]: +def get_node_keys_used_by_edges(edges_dict: dict[str, Edge]) -> set[str]: return {node_key for edge in edges_dict.values() for node_key in [edge.subject, edge.object]} -def get_counts_by_qg_id(dict_kg: QGOrganizedKnowledgeGraph) -> Dict[str, int]: +def get_counts_by_qg_id(dict_kg: QGOrganizedKnowledgeGraph) -> dict[str, int]: counts_by_qg_id = dict() for qnode_key, nodes_dict in dict_kg.nodes_by_qg_id.items(): counts_by_qg_id[qnode_key] = len(nodes_dict) @@ -201,7 +205,7 @@ def get_counts_by_qg_id(dict_kg: QGOrganizedKnowledgeGraph) -> Dict[str, int]: def get_printable_counts_by_qg_id(dict_kg: QGOrganizedKnowledgeGraph) -> str: counts_by_qg_id = get_counts_by_qg_id(dict_kg) counts_string = ", ".join([f"{qg_id}: {counts_by_qg_id[qg_id]}" for qg_id in sorted(counts_by_qg_id)]) - return counts_string if counts_string else "no answers" + return (counts_string + f", Unbound: {len(dict_kg.unbound_nodes)}") if counts_string else "no answers" def get_qg_without_kryptonite_portion(qg: QueryGraph) -> QueryGraph: @@ -278,8 +282,10 @@ def convert_qg_organized_kg_to_standard_kg(organized_kg: QGOrganizedKnowledgeGra return standard_kg -def get_curie_synonyms(curie: Union[str, List[str]], log: Optional[ARAXResponse] = ARAXResponse()) -> List[str]: +def get_curie_synonyms(curie: Union[str, list[str]], log: Optional[ARAXResponse]) -> list[str]: curies = convert_to_list(curie) + if log is None: + log = ARAXResponse() try: synonymizer = NodeSynonymizer() log.debug(f"Sending NodeSynonymizer.get_equivalent_nodes() a list of {len(curies)} curies") @@ -288,7 +294,8 @@ def get_curie_synonyms(curie: Union[str, List[str]], log: Optional[ARAXResponse] except Exception: tb = traceback.format_exc() error_type, error, _ = sys.exc_info() - log.error(f"Encountered a problem using NodeSynonymizer: {tb}", error_code=error_type.__name__) + log.error(f"Encountered a problem using NodeSynonymizer: {tb}", + error_code=error_type.__name__) # type: ignore[union-attr] return [] else: if equivalent_curies_dict is not None: @@ -304,8 +311,11 @@ def get_curie_synonyms(curie: Union[str, List[str]], log: Optional[ARAXResponse] return [] -def get_curie_synonyms_dict(curie: Union[str, List[str]], log: Optional[ARAXResponse] = ARAXResponse()) -> Dict[str, List[str]]: +def get_curie_synonyms_dict(curie: Union[str, list[str]], + log: Optional[ARAXResponse] = None) -> dict[str, list[str]]: curies = convert_to_list(curie) + if log is None: + log = ARAXResponse() try: synonymizer = NodeSynonymizer() log.debug(f"Sending NodeSynonymizer.get_equivalent_nodes() a list of {len(curies)} curies") @@ -314,7 +324,8 @@ def get_curie_synonyms_dict(curie: Union[str, List[str]], log: Optional[ARAXResp except Exception: tb = traceback.format_exc() error_type, error, _ = sys.exc_info() - log.error(f"Encountered a problem using NodeSynonymizer: {tb}", error_code=error_type.__name__) + log.error(f"Encountered a problem using NodeSynonymizer: {tb}", + error_code=error_type.__name__) # type: ignore[union-attr] return dict() else: if equivalent_curies_dict is not None: @@ -331,7 +342,7 @@ def get_curie_synonyms_dict(curie: Union[str, List[str]], log: Optional[ARAXResp return dict() -def get_canonical_curies_dict(curie: Union[str, List[str]], log: ARAXResponse) -> Dict[str, Dict[str, str]]: +def get_canonical_curies_dict(curie: Union[str, list[str]], log: ARAXResponse) -> dict[str, dict[str, str]]: curies = convert_to_list(curie) try: synonymizer = NodeSynonymizer() @@ -341,7 +352,8 @@ def get_canonical_curies_dict(curie: Union[str, List[str]], log: ARAXResponse) - except Exception: tb = traceback.format_exc() error_type, error, _ = sys.exc_info() - log.error(f"Encountered a problem using NodeSynonymizer: {tb}", error_code=error_type.__name__) + log.error(f"Encountered a problem using NodeSynonymizer: {tb}", + error_code=error_type.__name__) # type: ignore[union-attr] return {} else: if canonical_curies_dict is not None: @@ -354,7 +366,7 @@ def get_canonical_curies_dict(curie: Union[str, List[str]], log: ARAXResponse) - return {} -def get_canonical_curies_list(curie: Union[str, List[str]], log: ARAXResponse) -> List[str]: +def get_canonical_curies_list(curie: Union[str, list[str]], log: ARAXResponse) -> list[str]: curies = convert_to_list(curie) try: synonymizer = NodeSynonymizer() @@ -364,7 +376,8 @@ def get_canonical_curies_list(curie: Union[str, List[str]], log: ARAXResponse) - except Exception: tb = traceback.format_exc() error_type, error, _ = sys.exc_info() - log.error(f"Encountered a problem using NodeSynonymizer: {tb}", error_code=error_type.__name__) + log.error(f"Encountered a problem using NodeSynonymizer: {tb}", + error_code=error_type.__name__) # type: ignore[union-attr] return [] else: if canonical_curies_dict is not None: @@ -383,7 +396,7 @@ def get_canonical_curies_list(curie: Union[str, List[str]], log: ARAXResponse) - return [] -def get_preferred_categories(curie: Union[str, List[str]], log: ARAXResponse) -> Optional[List[str]]: +def get_preferred_categories(curie: Union[str, list[str]], log: ARAXResponse) -> Optional[list[str]]: curies = convert_to_list(curie) synonymizer = NodeSynonymizer() log.debug(f"Sending NodeSynonymizer.get_canonical_curies() a list of {len(curies)} curies") @@ -406,7 +419,7 @@ def get_preferred_categories(curie: Union[str, List[str]], log: ARAXResponse) -> return [] -def get_curie_names(curie: Union[str, List[str]], log: ARAXResponse) -> Dict[str, str]: +def get_curie_names(curie: Union[str, list[str]], log: ARAXResponse) -> dict[str, str]: curies = convert_to_list(curie) synonymizer = NodeSynonymizer() log.debug(f"Looking up names for {len(curies)} input curies using NodeSynonymizer.get_curie_names()") @@ -414,8 +427,11 @@ def get_curie_names(curie: Union[str, List[str]], log: ARAXResponse) -> Dict[str return curie_to_name_map -def qg_is_fulfilled(query_graph: QueryGraph, dict_kg: QGOrganizedKnowledgeGraph, enforce_required_only=False, - enforce_expanded_only=False, return_unfulfilled_qedges: bool = False) -> any: +def qg_is_fulfilled(query_graph: QueryGraph, + dict_kg: QGOrganizedKnowledgeGraph, + enforce_required_only=False, + enforce_expanded_only=False, + return_unfulfilled_qedges: bool = False) -> tuple[bool, Union[bool, set[str]]]: if enforce_required_only: qg_without_kryptonite_portion = get_qg_without_kryptonite_portion(query_graph) query_graph = get_required_portion_of_qg(qg_without_kryptonite_portion) @@ -432,7 +448,7 @@ def qg_is_fulfilled(query_graph: QueryGraph, dict_kg: QGOrganizedKnowledgeGraph, for qnode_key in query_graph.nodes: if not dict_kg.nodes_by_qg_id.get(qnode_key): is_fulfilled = False - unfulfilled_qedge_keys = set() + unfulfilled_qedge_keys: set[str] = set() for qedge_key, qedge in query_graph.edges.items(): if not dict_kg.edges_by_qg_id.get(qedge_key): unfulfilled_qedge_keys.add(qedge_key) @@ -453,7 +469,7 @@ def qg_is_disconnected(qg: QueryGraph) -> bool: return True if not connected_qnode_key and qnode_keys_remaining else False -def find_qnode_connected_to_sub_qg(qnode_keys_to_connect_to: Set[str], qnode_keys_to_choose_from: Set[str], qg: QueryGraph) -> Tuple[str, Set[str]]: +def find_qnode_connected_to_sub_qg(qnode_keys_to_connect_to: set[str], qnode_keys_to_choose_from: set[str], qg: QueryGraph) -> tuple[str, set[str]]: """ This function selects a qnode ID from the qnode_keys_to_choose_from that connects to one or more of the qnode IDs in the qnode_keys_to_connect_to (which itself could be considered a sub-graph of the QG). It also returns the IDs @@ -469,7 +485,7 @@ def find_qnode_connected_to_sub_qg(qnode_keys_to_connect_to: Set[str], qnode_key return "", set() -def get_connected_qedge_keys(qnode_key: str, qg: QueryGraph) -> Set[str]: +def get_connected_qedge_keys(qnode_key: str, qg: QueryGraph) -> set[str]: return {qedge_key for qedge_key, qedge in qg.edges.items() if qnode_key in {qedge.subject, qedge.object}} @@ -485,7 +501,7 @@ def flip_edge(edge: Edge, new_predicate: str) -> Edge: return edge -def flip_qedge(qedge: QEdge, new_predicates: List[str]): +def flip_qedge(qedge: QEdge, new_predicates: list[str]): qedge.predicates = new_predicates original_subject = qedge.subject qedge.subject = qedge.object @@ -517,7 +533,7 @@ def get_computed_value_attribute() -> Attribute: "directly attachable to other edges.") -def sort_kps_for_asyncio(kp_names: Union[List[str], Set[str]], log: ARAXResponse) -> List[str]: +def sort_kps_for_asyncio(kp_names: Union[list[str], set[str]], log: ARAXResponse) -> list[str]: # Our in-house KPs block the multi-threading, because there's no request to wait for; so we process them first kp_names = set(kp_names) to_call_first = ["infores:arax-drug-treats-disease", "infores:arax-normalized-google-distance"] @@ -589,7 +605,7 @@ def remove_semmeddb_edges_and_nodes_with_low_publications(kg: KnowledgeGraph, except: tb = traceback.format_exc() error_type, error, _ = sys.exc_info() - log.error(tb, error_code=error_type.__name__) + log.error(tb, error_code=error_type.__name__) # type: ignore[union-attr] log.error(f"Something went wrong removing semmeddb edges from the knowledge graph") else: log.info(f"{edges_removed_counter} Semmeddb Edges with low publication count successfully removed") diff --git a/code/ARAX/ARAXQuery/Expand/kp_info_cacher.py b/code/ARAX/ARAXQuery/Expand/kp_info_cacher.py index b0c2f1359..09a894194 100644 --- a/code/ARAX/ARAXQuery/Expand/kp_info_cacher.py +++ b/code/ARAX/ARAXQuery/Expand/kp_info_cacher.py @@ -8,7 +8,7 @@ import pickle import sys from datetime import datetime, timedelta -from typing import Set, Dict, Optional +from typing import Optional import requests import requests_cache @@ -151,7 +151,7 @@ def load_kp_info_caches(self, log: ARAXResponse): # --------------------------------- METHODS FOR BUILDING META MAP ----------------------------------------------- # # --- Note: These methods can't go in KPSelector because it would create a circular dependency with this class -- # - def _build_meta_map(self, allowed_kps_dict: Dict[str, str]): + def _build_meta_map(self, allowed_kps_dict: dict[str, str]): # Start with whatever pre-existing meta map we might already have (can use this info in case an API fails) cache_file = pathlib.Path(self.smart_api_and_meta_map_cache ) if cache_file.exists(): @@ -201,7 +201,7 @@ def _build_meta_map(self, allowed_kps_dict: Dict[str, str]): @staticmethod def _convert_meta_kg_to_meta_map(kp_meta_kg: dict) -> dict: - kp_meta_map = dict() + kp_meta_map: dict[str, dict[str, set[str]]] = dict() for meta_edge in kp_meta_kg["edges"]: subject_category = meta_edge["subject"] object_category = meta_edge["object"] diff --git a/code/ARAX/ARAXQuery/Expand/kp_selector.py b/code/ARAX/ARAXQuery/Expand/kp_selector.py index 692520716..f8f74280b 100644 --- a/code/ARAX/ARAXQuery/Expand/kp_selector.py +++ b/code/ARAX/ARAXQuery/Expand/kp_selector.py @@ -2,7 +2,7 @@ import os import pprint import sys -from typing import Set, List, Optional +from typing import Optional from collections import defaultdict from itertools import product @@ -26,6 +26,8 @@ def __init__(self, kg2_mode: bool = False, log: ARAXResponse = ARAXResponse()): self.kg2_mode = kg2_mode self.kp_cacher = KPInfoCacher() self.meta_map, self.kp_urls, self.kps_excluded_by_version, self.kps_excluded_by_maturity = self._load_cached_kp_info() + if (not self.kg2_mode) and (self.kp_urls is None): + raise ValueError("KP info cache has not been filled and we are not in KG2 mode; cannot initialize KP selector") self.valid_kps = {"infores:rtx-kg2"} if self.kg2_mode else set(self.kp_urls.keys()) self.bh = BiolinkHelper() @@ -45,7 +47,7 @@ def _load_cached_kp_info(self) -> tuple: return (meta_map, smart_api_info["allowed_kp_urls"], smart_api_info["kps_excluded_by_version"], smart_api_info["kps_excluded_by_maturity"]) - def get_kps_for_single_hop_qg(self, qg: QueryGraph) -> Optional[Set[str]]: + def get_kps_for_single_hop_qg(self, qg: QueryGraph) -> Optional[set[str]]: """ This function returns the names of the KPs that say they can answer the given one-hop query graph (based on the categories/predicates the QG uses). @@ -123,7 +125,7 @@ def kp_accepts_single_hop_qg(self, qg: QueryGraph, kp: str) -> Optional[bool]: return kp_accepts - def get_desirable_equivalent_curies(self, curies: List[str], categories: Optional[List[str]], kp: str) -> List[str]: + def get_desirable_equivalent_curies(self, curies: list[str], categories: Optional[list[str]], kp: str) -> list[str]: """ For each input curie, this function returns an equivalent curie(s) that uses a prefix the KP supports. """ @@ -139,8 +141,8 @@ def get_desirable_equivalent_curies(self, curies: List[str], categories: Optiona supported_prefixes = self._get_supported_prefixes(eu.convert_to_list(categories), kp) self.log.debug(f"{kp}: Prefixes {kp} supports for categories {categories} (and descendants) are: " f"{supported_prefixes}") - converted_curies = set() - unsupported_curies = set() + converted_curies: set[str] = set() + unsupported_curies: set[str] = set() synonyms_dict = eu.get_curie_synonyms_dict(curies) # Convert each input curie to a preferred, supported prefix for input_curie, equivalent_curies in synonyms_dict.items(): @@ -199,16 +201,16 @@ def make_qg_use_supported_prefixes(self, qg: QueryGraph, kp_name: str, log: ARAX def _get_uppercase_prefix(curie: str) -> str: return curie.split(":")[0].upper() - def _get_supported_prefixes(self, categories: List[str], kp: str) -> Set[str]: + def _get_supported_prefixes(self, categories: list[str], kp: str) -> set[str]: categories_with_descendants = self.bh.get_descendants(eu.convert_to_list(categories), include_mixins=False) supported_prefixes = {prefix.upper() for category in categories_with_descendants for prefix in self.meta_map[kp]["prefixes"].get(category, set())} return supported_prefixes def _triple_is_in_meta_map(self, kp: str, - subject_categories: Set[str], - predicates: Set[str], - object_categories: Set[str]) -> bool: + subject_categories: set[str], + predicates: set[str], + object_categories: set[str]) -> bool: """ Returns True if at least one possible triple exists in the KP's meta map. NOT meant to handle empty predicates; you should sub in "biolink:related_to" for QEdges without predicates before calling this method. @@ -227,9 +229,7 @@ def _triple_is_in_meta_map(self, kp: str, if not subject_categories: # any subject subject_categories = set(predicates_map.keys()) if not object_categories: # any object - object_set = set() - _ = [object_set.add(obj) for obj_dict in predicates_map.values() for obj in obj_dict.keys()] - object_categories = object_set + object_categories = {obj for obj_dict in predicates_map.values() for obj in obj_dict.keys()} # handle combinations of subject and objects using cross product qg_sub_obj_dict = defaultdict(lambda: set()) diff --git a/code/ARAX/ARAXQuery/Expand/trapi_querier.py b/code/ARAX/ARAXQuery/Expand/trapi_querier.py index 2769f4eeb..58006ad8c 100644 --- a/code/ARAX/ARAXQuery/Expand/trapi_querier.py +++ b/code/ARAX/ARAXQuery/Expand/trapi_querier.py @@ -10,7 +10,7 @@ import aiohttp import asyncio import requests -from typing import List, Dict, Set, Union, Optional, Tuple +from typing import Union, Optional, Any, cast import requests_cache @@ -21,7 +21,6 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/../") # ARAXQuery directory from ARAX_response import ARAXResponse from ARAX_messenger import ARAXMessenger -from ARAX_query import ARAXQuery sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/../../UI/OpenAPI/python-flask-server/") from openapi_server.models.node import Node from openapi_server.models.edge import Edge @@ -36,7 +35,7 @@ def _remove_attributes_with_invalid_values(response_json: dict, kp_curie: str, log: ARAXResponse) -> \ - tuple[dict, int]: + list[object]: r = response_json count_att_dropped = 0 for ekey, edge_obj in r['message']['knowledge_graph']['edges'].items(): @@ -65,8 +64,11 @@ def _remove_attributes_with_invalid_values(response_json: dict, class TRAPIQuerier: - def __init__(self, response_object: ARAXResponse, kp_name: str, user_specified_kp: bool, kp_timeout: Optional[int], - kp_selector: KPSelector = None): + def __init__(self, response_object: ARAXResponse, + kp_name: str, + user_specified_kp: bool, + kp_timeout: Optional[int], + kp_selector: Optional[KPSelector] = None): self.log = response_object self.kp_infores_curie = kp_name self.user_specified_kp = user_specified_kp @@ -75,7 +77,7 @@ def __init__(self, response_object: ARAXResponse, kp_name: str, user_specified_k kp_selector = KPSelector() self.kp_selector = kp_selector self.kp_endpoint = kp_selector.kp_urls[self.kp_infores_curie] - self.qnodes_with_single_id = dict() # This is set during the processing of each query + self.qnodes_with_single_id: dict[str, str] = dict() # This is set during the processing of each query self.arax_infores_curie = "infores:arax" self.arax_retrieval_source = RetrievalSource(resource_id=self.arax_infores_curie, resource_role="aggregator_knowledge_source", @@ -205,7 +207,7 @@ def _verify_is_single_node_query_graph(self, query_graph: QueryGraph): self.log.error(f"answer_single_node_query() was passed a query graph that has edges: " f"{query_graph.to_dict()}", error_code="InvalidQuery") - def _get_kg_to_qg_mappings_from_results(self, results: List[Result], qg: QueryGraph) -> Tuple[Dict[str, Dict[str, Set[str]]], Dict[str, Set[str]]]: + def _get_kg_to_qg_mappings_from_results(self, results: list[Result], qg: QueryGraph) -> tuple[dict[str, dict[str, set[str]]], dict[str, set[str]]]: """ This function returns a dictionary in which one can lookup which qnode_keys/qedge_keys a given node/edge fulfills. Like: {"nodes": {"PR:11": {"n00"}, "MESH:22": {"n00", "n01"} ... }, "edges": { ... }} @@ -302,12 +304,12 @@ async def _answer_query_using_kp_async(self, query_graph: QueryGraph) -> QGOrgan self.log.warning(f"{self.kp_infores_curie}: {exception_message}") self.log.update_query_plan(qedge_key, self.kp_infores_curie, "Error", exception_message) return QGOrganizedKnowledgeGraph() - wait_time = round(time.time() - start) json_response, cd = \ _remove_attributes_with_invalid_values(json_response, self.kp_infores_curie, self.log) + json_response = cast(dict[str, Any], json_response) answer_kg = self._load_kp_json_response(json_response, query_graph) num_edges = len(answer_kg.edges_by_qg_id.get(qedge_key, dict())) done_message = f"Returned {num_edges} edges in {wait_time} seconds" @@ -354,6 +356,7 @@ def _answer_query_using_kp(self, query_graph: QueryGraph) -> QGOrganizedKnowledg json_response, _ = _remove_attributes_with_invalid_values(json_response, self.kp_infores_curie, self.log) + json_response = cast(dict[str, Any], json_response) answer_kg = self._load_kp_json_response(json_response, query_graph) return answer_kg @@ -371,8 +374,8 @@ def _get_prepped_request_body(self, qg: QueryGraph) -> dict: # Load the query into a JSON Query object json_qg = {'nodes': stripped_qnodes, 'edges': stripped_qedges} - body = {'message': {'query_graph': json_qg}} - body['submitter'] = "infores:arax" + body: dict[str, Any] = {'message': {'query_graph': json_qg}, + 'submitter': 'infores:arax'} if self.kp_infores_curie == "infores:rtx-kg2": body['return_minimal_metadata'] = True # Don't want KG2 attributes because ARAX adds them later (faster) return body @@ -437,18 +440,16 @@ def _load_kp_json_response(self, json_response: dict, qg: QueryGraph) -> QGOrgan answer_kg.add_edge(arax_edge_key, returned_edge, qedge_key) else: returned_edge_keys_missing_qg_bindings.add(returned_edge_key) - if returned_edge_keys_missing_qg_bindings: - self.log.warning(f"{self.kp_infores_curie}: {len(returned_edge_keys_missing_qg_bindings)} edges in the KP's answer " - f"KG have no bindings to the QG: {returned_edge_keys_missing_qg_bindings}") # Populate our final KG with the returned nodes - returned_node_keys_missing_qg_bindings = set() + returned_node_keys_missing_qg_bindings = dict() + nodes_referenced_in_result_analysis_edges = set() for returned_node_key, returned_node in kp_message.knowledge_graph.nodes.items(): if not returned_node_key: self.log.warning(f"{self.kp_infores_curie}: Node has empty ID, skipping. Node key is: " f"'{returned_node_key}'") elif returned_node_key not in kg_to_qg_mappings['nodes']: - returned_node_keys_missing_qg_bindings.add(returned_node_key) + returned_node_keys_missing_qg_bindings[returned_node_key] = returned_node else: for qnode_key in kg_to_qg_mappings['nodes'][returned_node_key]: answer_kg.add_node(returned_node_key, returned_node, qnode_key) @@ -457,8 +458,37 @@ def _load_kp_json_response(self, json_response: dict, qg: QueryGraph) -> QGOrgan if not attribute.attribute_type_id: attribute.attribute_type_id = f"not provided (this attribute came from {self.kp_infores_curie})" if returned_node_keys_missing_qg_bindings: - self.log.warning(f"{self.kp_infores_curie}: {len(returned_node_keys_missing_qg_bindings)} nodes in the KP's answer " - f"KG have no bindings to the QG: {returned_node_keys_missing_qg_bindings}") + for result in kp_message.results: + for analysis in result.analyses: + for qedge_key, edge_bindings in analysis.edge_bindings.items(): + if qedge_key in qg.edges: + for edge_binding in edge_bindings: + edge_id = edge_binding.id + if edge_id in kp_message.knowledge_graph.edges: + edge = kp_message.knowledge_graph.edges[edge_id] + nodes_referenced_in_result_analysis_edges.add(edge.subject) + nodes_referenced_in_result_analysis_edges.add(edge.object) + + allowed_unbound_nodes = dict() + unreferenced_unbound_nodes = dict() + for node_key, node in returned_node_keys_missing_qg_bindings.items(): + if node_key in nodes_referenced_in_result_analysis_edges: + allowed_unbound_nodes[node_key] = node + else: + unreferenced_unbound_nodes[node_key] = node + answer_kg.unbound_nodes = allowed_unbound_nodes + allowed_unbound_edges = set() + if returned_edge_keys_missing_qg_bindings: + for aux_graph_id, aux_graph in kp_message.auxiliary_graphs.items(): + for edge_id in aux_graph.edges: + allowed_unbound_edges.add(edge_id) + returned_edge_keys_missing_qg_bindings -= allowed_unbound_edges + if returned_edge_keys_missing_qg_bindings: + self.log.warning(f"{self.kp_infores_curie}: {len(returned_edge_keys_missing_qg_bindings)} edges in the KP's answer " + f"KG have no bindings to the QG: {returned_edge_keys_missing_qg_bindings}") + if unreferenced_unbound_nodes: + self.log.warning(f"{self.kp_infores_curie}: {len(unreferenced_unbound_nodes)} nodes in the KP's answer " + f"KG have no bindings to the QG and are not referenced in any analysis: {set(unreferenced_unbound_nodes.keys())}") # Fill out our unofficial node.query_ids property for nodes in answer_kg.nodes_by_qg_id.values(): @@ -471,7 +501,7 @@ def _load_kp_json_response(self, json_response: dict, qg: QueryGraph) -> QGOrgan return answer_kg @staticmethod - def _strip_empty_properties(qnode_or_qedge: Union[QNode, QEdge]) -> Dict[str, any]: + def _strip_empty_properties(qnode_or_qedge: Union[QNode, QEdge]) -> dict[str, Any]: dict_version_of_object = qnode_or_qedge.to_dict() stripped_dict = {property_name: value for property_name, value in dict_version_of_object.items() if dict_version_of_object.get(property_name) not in [None, []]} diff --git a/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py b/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py index ab6c3df42..84ca192f0 100644 --- a/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py +++ b/code/ARAX/ARAXQuery/Infer/scripts/creativeCRG.py @@ -12,7 +12,6 @@ RTXindex = pathlist.index("RTX") sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery'])) from ARAX_response import ARAXResponse -from ARAX_query import ARAXQuery sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'UI', 'OpenAPI', 'python-flask-server'])) import openapi_server sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'NodeSynonymizer'])) diff --git a/code/ARAX/NodeSynonymizer/node_synonymizer.py b/code/ARAX/NodeSynonymizer/node_synonymizer.py index 35b74e535..f51e10ce6 100644 --- a/code/ARAX/NodeSynonymizer/node_synonymizer.py +++ b/code/ARAX/NodeSynonymizer/node_synonymizer.py @@ -19,7 +19,6 @@ from RTXConfiguration import RTXConfiguration sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'ARAX', 'ARAXQuery'])) -from ARAX_database_manager import ARAXDatabaseManager sys.path.append(os.path.sep.join([*pathlist[:(RTXindex + 1)], 'code', 'UI', 'OpenAPI', 'python-flask-server'])) from openapi_server.models.knowledge_graph import KnowledgeGraph