From fa3e6ff9462ff45d9c9d0e819fc6192b42aa9533 Mon Sep 17 00:00:00 2001 From: maleicacid <4982384+kazuki0824@users.noreply.github.com> Date: Mon, 9 Dec 2024 01:19:53 +0900 Subject: [PATCH] Implement condensation tentatively Update mod.rs add wrapper Update rustworkx.pyi Update rustworkx.pyi --- rustworkx/__init__.pyi | 1 + rustworkx/rustworkx.pyi | 2 +- src/connectivity/mod.rs | 45 ++++++++++++++++++++++++++++------------- src/lib.rs | 1 + 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 5952e177e..d1698e969 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -71,6 +71,7 @@ from .rustworkx import number_connected_components as number_connected_component from .rustworkx import number_weakly_connected_components as number_weakly_connected_components from .rustworkx import node_connected_component as node_connected_component from .rustworkx import strongly_connected_components as strongly_connected_components +from .rustworkx import condensation as condensation from .rustworkx import weakly_connected_components as weakly_connected_components from .rustworkx import digraph_adjacency_matrix as digraph_adjacency_matrix from .rustworkx import graph_adjacency_matrix as graph_adjacency_matrix diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index b414d07de..ebbfc223a 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -192,7 +192,7 @@ def number_connected_components(graph: PyGraph, /) -> int: ... def number_weakly_connected_components(graph: PyDiGraph, /) -> bool: ... def node_connected_component(graph: PyGraph, node: int, /) -> set[int]: ... def strongly_connected_components(graph: PyDiGraph, /) -> list[list[int]]: ... -def condensation(graph: PyDiGraph, /) -> list[list[int]]: ... +def condensation(graph: PyDiGraph, /, sccs=None) -> PyDiGraph: ... def weakly_connected_components(graph: PyDiGraph, /) -> list[set[int]]: ... def digraph_adjacency_matrix( graph: PyDiGraph[_S, _T], diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index 65e9160a1..a22e4036b 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -21,8 +21,8 @@ use super::{ }; use hashbrown::{HashMap, HashSet}; -use petgraph::algo; -use petgraph::graph::DiGraph; +use petgraph::{algo, Graph}; +use petgraph::graph::{DiGraph, IndexType}; use petgraph::stable_graph::NodeIndex; use petgraph::unionfind::UnionFind; use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable}; @@ -35,6 +35,7 @@ use rayon::prelude::*; use ndarray::prelude::*; use numpy::IntoPyArray; use petgraph::algo::kosaraju_scc; +use petgraph::prelude::StableGraph; use crate::iterators::{ AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList, NodeIndices, @@ -114,18 +115,19 @@ pub fn strongly_connected_components(graph: &digraph::PyDiGraph) -> Vec>>) - -> digraph::PyDiGraph { - use petgraph::graph::NodeIndex; - use petgraph::{Directed, Graph}; - - let g = graph.graph; - - // TODO: Override sccs from arg +pub fn condensation_inner<'a, N, E, Ty, Ix>( + py: &'a Python, + g: Graph, + make_acyclic: bool, +) -> StableGraph + where + Ty: EdgeType, + Ix: IndexType, + N: ToPyObject, + E: ToPyObject +{ let sccs = kosaraju_scc(&g); - let mut condensed: Graph, E, Ty, Ix> = Graph::with_capacity(sccs.len(), g.edge_count()); + let mut condensed: StableGraph, E, Ty, Ix> = StableGraph::with_capacity(sccs.len(), g.edge_count()); // Build a map from old indices to new ones. let mut node_map = vec![NodeIndex::end(); g.node_count()]; @@ -152,12 +154,27 @@ pub fn condensation(py: Python, graph: &digraph::PyDiGraph, sccs: Option>>) + -> digraph::PyDiGraph { + let g = graph.graph.clone(); + + // TODO: Override sccs from arg + let condensed = if let Some(sccs) = sccs { + unimplemented!("") + } else { + condensation_inner(&py, g.into(), true) + }; // TODO: Fit for networkx let result = condensed; digraph::PyDiGraph { - graph: result.into(), + graph: result, cycle_state: algo::DfsSpace::default(), check_cycle: false, node_removed: false, diff --git a/src/lib.rs b/src/lib.rs index 4ee4189a7..21661b4e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -569,6 +569,7 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(cycle_basis))?; m.add_wrapped(wrap_pyfunction!(simple_cycles))?; m.add_wrapped(wrap_pyfunction!(strongly_connected_components))?; + m.add_wrapped(wrap_pyfunction!(condensation))?; m.add_wrapped(wrap_pyfunction!(digraph_dfs_edges))?; m.add_wrapped(wrap_pyfunction!(graph_dfs_edges))?; m.add_wrapped(wrap_pyfunction!(digraph_find_cycle))?;