diff --git a/src/gocam/translation/cx2/main.py b/src/gocam/translation/cx2/main.py index 3f93cef..f329099 100644 --- a/src/gocam/translation/cx2/main.py +++ b/src/gocam/translation/cx2/main.py @@ -66,7 +66,8 @@ def _get_context(): def model_to_cx2(gocam: Model) -> list: # Internal state input_output_nodes: Dict[str, int] = {} - activity_nodes: Dict[str, int] = {} + activity_nodes_by_activity_id: Dict[str, int] = {} + activity_nodes_by_enabled_by_id: Dict[str, int] = {} # Internal helper functions that access internal state def _get_object_label(object_id: str) -> str: @@ -82,7 +83,11 @@ def _add_input_output_nodes( if not isinstance(associations, list): associations = [associations] for association in associations: - if association.term not in input_output_nodes: + if association.term in activity_nodes_by_enabled_by_id: + target = activity_nodes_by_enabled_by_id[association.term] + elif association.term in input_output_nodes: + target = input_output_nodes[association.term] + else: node_attributes = { "name": _get_object_label(association.term), "represents": association.term, @@ -94,13 +99,12 @@ def _add_input_output_nodes( p.contributor for p in association.provenances ] - input_output_nodes[association.term] = cx2_network.add_node( - attributes=node_attributes - ) + target = cx2_network.add_node(attributes=node_attributes) + input_output_nodes[association.term] = target cx2_network.add_edge( - source=activity_nodes[activity.id], - target=input_output_nodes[association.term], + source=activity_nodes_by_activity_id[activity.id], + target=target, attributes=edge_attributes, ) @@ -174,7 +178,9 @@ def _add_input_output_nodes( p.contributor for p in activity.provenances ] - activity_nodes[activity.id] = cx2_network.add_node(attributes=node_attributes) + node = cx2_network.add_node(attributes=node_attributes) + activity_nodes_by_activity_id[activity.id] = node + activity_nodes_by_enabled_by_id[activity.enabled_by.term] = node # Add nodes for input/output molecules and create edges to activity nodes for activity in gocam.activities: @@ -196,7 +202,7 @@ def _add_input_output_nodes( # Add edges for causal associations between activity nodes for activity in gocam.activities: for association in activity.causal_associations: - if association.downstream_activity in activity_nodes: + if association.downstream_activity in activity_nodes_by_activity_id: relation_style = RELATIONS.get(association.predicate, None) if relation_style is None: logger.warning( @@ -221,8 +227,10 @@ def _add_input_output_nodes( ] cx2_network.add_edge( - source=activity_nodes[activity.id], - target=activity_nodes[association.downstream_activity], + source=activity_nodes_by_activity_id[activity.id], + target=activity_nodes_by_activity_id[ + association.downstream_activity + ], attributes=edge_attributes, )