Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ code/ARAX/ARAXQuery/Infer/data/xCRG_data/*.npz

code/UI/OpenAPI/python-flask-server/openapi_server/openapi/openapi.json
code/UI/OpenAPI/specifications/export/KG2/*/openapi.json
**/cache_refresh.pid
74 changes: 44 additions & 30 deletions code/ARAX/ARAXQuery/ARAX_expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import traceback
from collections import defaultdict
from typing import List, Dict, Tuple, Union, Set, Optional
from typing import Union, Optional, Any

sys.path.append(os.path.dirname(os.path.abspath(__file__))) # ARAXQuery directory
from ARAX_response import ARAXResponse
Expand All @@ -31,6 +31,7 @@
from openapi_server.models.retrieval_source import RetrievalSource
from Expand.trapi_querier import TRAPIQuerier

UNBOUND_NODES_KEY = "__UNBOUND__"

def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs)

Expand Down Expand Up @@ -147,7 +148,11 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
message.knowledge_graph = KnowledgeGraph(nodes=dict(), edges=dict())
log = response
# Fetch the list of all registered kps with compatible versions
kp_selector = KPSelector(log=log)
try:
kp_selector = KPSelector(log=log)
except ValueError as e:
response.error(str(e))
return response

# Save the original QG, if it hasn't already been saved in ARAXQuery (happens for DSL queries..)
if not hasattr(response, "original_query_graph"):
Expand Down Expand Up @@ -597,9 +602,9 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
return response

@staticmethod
def get_inferred_answers(inferred_qedge_keys: List[str],
def get_inferred_answers(inferred_qedge_keys: list[str],
query_graph: QueryGraph,
response: ARAXResponse) -> Tuple[ARAXResponse, QGOrganizedKnowledgeGraph]:
response: ARAXResponse) -> tuple[ARAXResponse, QGOrganizedKnowledgeGraph]:
# Send ARAXInfer any one-hop, "inferred", "treats" queries (temporary way of making creative mode work)
overarching_kg = QGOrganizedKnowledgeGraph()
if inferred_qedge_keys:
Expand Down Expand Up @@ -745,7 +750,7 @@ async def _expand_edge_async(self, edge_qg: QueryGraph,
kp_selector: KPSelector,
log: ARAXResponse,
multiple_kps: bool = False,
be_creative_treats: bool = False) -> Tuple[QGOrganizedKnowledgeGraph, ARAXResponse]:
be_creative_treats: bool = False) -> tuple[QGOrganizedKnowledgeGraph, ARAXResponse]:
# This function answers a single-edge (one-hop) query using the specified knowledge provider
qedge_key = next(qedge_key for qedge_key in edge_qg.edges)
log.info(f"Expanding qedge {qedge_key} using {kp_to_use}")
Expand Down Expand Up @@ -815,7 +820,7 @@ async def _expand_edge_async(self, edge_qg: QueryGraph,

@staticmethod
def _expand_node(qnode_key: str,
kps_to_use: List[str],
kps_to_use: list[str],
query_graph: QueryGraph,
user_specified_kp: bool,
kp_timeout: Optional[int],
Expand Down Expand Up @@ -901,8 +906,8 @@ def _deduplicate_nodes(answer_kg: QGOrganizedKnowledgeGraph, kp_name: str, log:
edges={qedge_key: dict() for qedge_key in answer_kg.edges_by_qg_id})
curie_mappings = dict()

# First deduplicate the nodes
for qnode_key, nodes in answer_kg.nodes_by_qg_id.items():
# First deduplicate the bound nodes
for qnode_key, nodes in {**answer_kg.nodes_by_qg_id, UNBOUND_NODES_KEY: answer_kg.unbound_nodes}.items():
# Load preferred curie info from NodeSynonymizer
log.debug(f"{kp_name}: Getting preferred curies for {qnode_key} nodes returned in this step")
canonicalized_nodes = eu.get_canonical_curies_dict(list(nodes), log) if nodes else dict()
Expand All @@ -927,11 +932,17 @@ def _deduplicate_nodes(answer_kg: QGOrganizedKnowledgeGraph, kp_name: str, log:
curie_mappings[node_key] = preferred_curie

# Add this node into our deduplicated KG as necessary
if preferred_curie not in deduplicated_kg.nodes_by_qg_id[qnode_key]:
node_key = preferred_curie
node.name = preferred_name
node.categories = preferred_categories
deduplicated_kg.add_node(node_key, node, qnode_key)
if qnode_key != UNBOUND_NODES_KEY:
if preferred_curie not in deduplicated_kg.nodes_by_qg_id[qnode_key]:
node_key = preferred_curie
node.name = preferred_name
node.categories = preferred_categories
deduplicated_kg.add_node(node_key, node, qnode_key)
else: # this is an unbound node
if preferred_curie not in deduplicated_kg.unbound_nodes:
node.name = preferred_name
node.categories = preferred_categories
deduplicated_kg.unbound_nodes[preferred_curie] = node

# Then update the edges to reflect changes made to the nodes
dropped_edge_count = dict()
Expand All @@ -957,7 +968,7 @@ def _deduplicate_nodes(answer_kg: QGOrganizedKnowledgeGraph, kp_name: str, log:
return deduplicated_kg, dropped_edge_count

@staticmethod
def _extract_query_subgraph(qedge_keys_to_expand: List[str], query_graph: QueryGraph, log: ARAXResponse) -> QueryGraph:
def _extract_query_subgraph(qedge_keys_to_expand: list[str], query_graph: QueryGraph, log: ARAXResponse) -> QueryGraph:
# This function extracts a sub-query graph containing the provided qedge IDs from a larger query graph
sub_query_graph = QueryGraph(nodes=dict(), edges=dict())

Expand Down Expand Up @@ -1059,7 +1070,7 @@ def _merge_answer_into_message_kg(answer_kg: QGOrganizedKnowledgeGraph, overarch

@staticmethod
def _store_kryptonite_edge_info(kryptonite_kg: QGOrganizedKnowledgeGraph, kryptonite_qedge_key: str, qg: QueryGraph,
encountered_kryptonite_edges_info: Dict[str, Dict[str, Set[str]]], log: ARAXResponse):
encountered_kryptonite_edges_info: dict[str, dict[str, set[str]]], log: ARAXResponse):
"""
This function adds the IDs of nodes found by expansion of the given kryptonite ("not") edge to the global
encountered_kryptonite_edges_info dictionary, which is organized by QG IDs. This allows Expand to "remember"
Expand Down Expand Up @@ -1087,7 +1098,7 @@ def _store_kryptonite_edge_info(kryptonite_kg: QGOrganizedKnowledgeGraph, krypto

@staticmethod
def _apply_any_kryptonite_edges(organized_kg: QGOrganizedKnowledgeGraph, full_query_graph: QueryGraph,
encountered_kryptonite_edges_info: Dict[str, Dict[str, Set[str]]], log):
encountered_kryptonite_edges_info: dict[str, dict[str, set[str]]], log):
"""
This function breaks any paths in the KG for which a "not" (exclude=True) condition has been met; the remains
of the broken paths not used in other paths in the KG are cleaned up during later pruning of dead ends. The
Expand Down Expand Up @@ -1148,16 +1159,19 @@ def _prune_kg(qnode_key_to_prune: str, prune_threshold: int, kg: QGOrganizedKnow
num_edges_in_kg = sum([len(edges) for edges in kg.edges_by_qg_id.values()])
overlay_fet = True if num_edges_in_kg < 100000 else False
# Use fisher exact test and the ranker to prune down answers for this qnode
intermediate_results_response = eu.create_results(qg_for_resultify, kg_copy, log,
rank_results=True, overlay_fet=overlay_fet,
intermediate_results_response = eu.create_results(qg_for_resultify,
kg_copy,
log,
rank_results=True,
overlay_fet=overlay_fet,
qnode_key_to_prune=qnode_key_to_prune)
log.debug(f"A total of {len(intermediate_results_response.envelope.message.results)} "
f"intermediate results were created/ranked")
if intermediate_results_response.status == "OK":
# Filter down so we only keep the top X nodes
results = intermediate_results_response.envelope.message.results
results.sort(key=lambda x: x.analyses[0].score, reverse=True)
kept_nodes = set()
kept_nodes: set[str] = set()
scores = []
counter = 0
while len(kept_nodes) < prune_threshold and counter < len(results):
Expand Down Expand Up @@ -1203,7 +1217,7 @@ def _remove_dead_end_paths(expands_qg: QueryGraph, kg: QGOrganizedKnowledgeGraph

@staticmethod
def _add_node_connection_to_map(qnode_key_a: str, qnode_key_b: str, edge: Edge,
node_connections_map: Dict[str, Dict[str, Dict[str, Set[str]]]]):
node_connections_map: dict[str, dict[str, dict[str, set[str]]]]):
# This is a helper function that's used when building a map of which nodes are connected to which
# Example node_connections_map: {'n01': {'UMLS:1222': {'n00': {'DOID:122'}, 'n02': {'UniProtKB:22'}}}, ...}
# Initiate entries for this edge's nodes as needed
Expand All @@ -1220,14 +1234,14 @@ def _add_node_connection_to_map(qnode_key_a: str, qnode_key_b: str, edge: Edge,
node_connections_map[qnode_key_a][edge.subject][qnode_key_b].add(edge.object)
node_connections_map[qnode_key_b][edge.object][qnode_key_a].add(edge.subject)

def _get_order_to_expand_qedges_in(self, query_graph: QueryGraph, log: ARAXResponse) -> List[str]:
def _get_order_to_expand_qedges_in(self, query_graph: QueryGraph, log: ARAXResponse) -> list[str]:
"""
This function determines what order to expand the edges in a query graph in; it aims to start with a required,
non-kryptonite qedge that has a qnode with a curie specified. It then looks for a qedge connected to that
starting qedge, and so on.
"""
qedge_keys_remaining = [qedge_key for qedge_key in query_graph.edges]
ordered_qedge_keys = []
ordered_qedge_keys: list[str] = []
while qedge_keys_remaining:
if not ordered_qedge_keys:
# Try to start with a required, non-kryptonite qedge that has a qnode with a curie specified
Expand Down Expand Up @@ -1258,7 +1272,7 @@ def _get_order_to_expand_qedges_in(self, query_graph: QueryGraph, log: ARAXRespo
return ordered_qedge_keys

@staticmethod
def _find_qedge_connected_to_subgraph(subgraph_qedge_keys: List[str], qedge_keys_to_choose_from: List[str],
def _find_qedge_connected_to_subgraph(subgraph_qedge_keys: list[str], qedge_keys_to_choose_from: list[str],
qg: QueryGraph) -> Optional[str]:
qnode_keys_in_subgraph = {qnode_key for qedge_key in subgraph_qedge_keys for qnode_key in
{qg.edges[qedge_key].subject, qg.edges[qedge_key].object}}
Expand Down Expand Up @@ -1322,9 +1336,9 @@ def _remove_self_edges(kg: QGOrganizedKnowledgeGraph, kp_name: str, log: ARAXRes
log.debug(f"{kp_name}: After removing self-edges, answer KG counts are: {eu.get_printable_counts_by_qg_id(kg)}")
return kg

def _set_and_validate_parameters(self, input_parameters: Dict[str, any], kp_selector: KPSelector,
log: ARAXResponse) -> Dict[str, any]:
parameters = dict()
def _set_and_validate_parameters(self, input_parameters: dict[str, Any], kp_selector: KPSelector,
log: ARAXResponse) -> dict[str, Any]:
parameters: dict[str, Any] = dict()
parameter_info_dict = self.get_parameter_info_dict()

# First set parameters to their defaults
Expand Down Expand Up @@ -1359,7 +1373,7 @@ def _set_and_validate_parameters(self, input_parameters: Dict[str, any], kp_sele
return parameters

@staticmethod
def is_supported_constraint(constraint: AttributeConstraint, supported_constraints_map: Dict[str, Dict[str, Set[str]]]) -> bool:
def is_supported_constraint(constraint: AttributeConstraint, supported_constraints_map: dict[str, dict[str, set[str]]]) -> bool:
if constraint.id not in supported_constraints_map:
return False
if constraint.operator not in supported_constraints_map[constraint.id]:
Expand All @@ -1371,7 +1385,7 @@ def is_supported_constraint(constraint: AttributeConstraint, supported_constrain
return True

@staticmethod
def _load_fda_approved_drug_ids() -> Set[str]:
def _load_fda_approved_drug_ids() -> set[str]:
# Determine the local path to the FDA-approved drugs pickle
path_list = os.path.realpath(__file__).split(os.path.sep)
rtx_index = path_list.index("RTX")
Expand Down Expand Up @@ -1467,12 +1481,12 @@ def _get_orphan_qnode_keys(query_graph: QueryGraph):
return list(all_qnode_keys.difference(qnode_keys_used_by_qedges))

@staticmethod
def _get_qedges_with_curie_qnode(query_graph: QueryGraph) -> List[str]:
def _get_qedges_with_curie_qnode(query_graph: QueryGraph) -> list[str]:
return [qedge_key for qedge_key, qedge in query_graph.edges.items()
if query_graph.nodes[qedge.subject].ids or query_graph.nodes[qedge.object].ids]

@staticmethod
def _find_connected_qedge(qedge_choices: List[QEdge], qedge: QEdge) -> QEdge:
def _find_connected_qedge(qedge_choices: list[QEdge], qedge: QEdge) -> QEdge:
qedge_qnode_keys = {qedge.subject, qedge.object}
connected_qedges = []
for other_qedge in qedge_choices:
Expand Down
4 changes: 3 additions & 1 deletion code/ARAX/ARAXQuery/ARAX_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def __init__(self, status='OK', logging_level=WARNING, error_code='OK', message=
self.n_warnings = 0
self.data = {}
self.envelope = None

self.query_plan = { 'qedge_keys': {}, 'counter': 0 }
self.wait_time = None # this attribute is set by trapi_querier.py
self.http_error = None # this attribute is set by trapi_querier.py
self.timed_out = None # this attribute is set by trapi_querier.py


#### Add a debugging message
Expand Down
Loading