Skip to content

Commit

Permalink
Implement condensation tentatively
Browse files Browse the repository at this point in the history
Update mod.rs

add wrapper

Update rustworkx.pyi

Update rustworkx.pyi
  • Loading branch information
kazuki0824 committed Dec 8, 2024
1 parent 11aac0d commit fa3e6ff
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
1 change: 1 addition & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ from .rustworkx import number_connected_components as number_connected_component
from .rustworkx import number_weakly_connected_components as number_weakly_connected_components
from .rustworkx import node_connected_component as node_connected_component
from .rustworkx import strongly_connected_components as strongly_connected_components
from .rustworkx import condensation as condensation
from .rustworkx import weakly_connected_components as weakly_connected_components
from .rustworkx import digraph_adjacency_matrix as digraph_adjacency_matrix
from .rustworkx import graph_adjacency_matrix as graph_adjacency_matrix
Expand Down
2 changes: 1 addition & 1 deletion rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def number_connected_components(graph: PyGraph, /) -> int: ...
def number_weakly_connected_components(graph: PyDiGraph, /) -> bool: ...
def node_connected_component(graph: PyGraph, node: int, /) -> set[int]: ...
def strongly_connected_components(graph: PyDiGraph, /) -> list[list[int]]: ...
def condensation(graph: PyDiGraph, /) -> list[list[int]]: ...
def condensation(graph: PyDiGraph, /, sccs=None) -> PyDiGraph: ...
def weakly_connected_components(graph: PyDiGraph, /) -> list[set[int]]: ...
def digraph_adjacency_matrix(
graph: PyDiGraph[_S, _T],
Expand Down
45 changes: 31 additions & 14 deletions src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use super::{
};

use hashbrown::{HashMap, HashSet};
use petgraph::algo;
use petgraph::graph::DiGraph;
use petgraph::{algo, Graph};
use petgraph::graph::{DiGraph, IndexType};
use petgraph::stable_graph::NodeIndex;
use petgraph::unionfind::UnionFind;
use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable};
Expand All @@ -35,6 +35,7 @@ use rayon::prelude::*;
use ndarray::prelude::*;
use numpy::IntoPyArray;
use petgraph::algo::kosaraju_scc;
use petgraph::prelude::StableGraph;

use crate::iterators::{
AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList, NodeIndices,
Expand Down Expand Up @@ -114,18 +115,19 @@ pub fn strongly_connected_components(graph: &digraph::PyDiGraph) -> Vec<Vec<usiz
.collect()
}

#[pyfunction]
#[pyo3(text_signature = "(graph, /, sccs=None)", signature=(graph, /, sccs=None))]
pub fn condensation(py: Python, graph: &digraph::PyDiGraph, sccs: Option<Vec<Vec<usize>>>)
-> digraph::PyDiGraph {
use petgraph::graph::NodeIndex;
use petgraph::{Directed, Graph};

let g = graph.graph;

// TODO: Override sccs from arg
pub fn condensation_inner<'a, N, E, Ty, Ix>(
py: &'a Python,
g: Graph<N, E, Ty, Ix>,
make_acyclic: bool,
) -> StableGraph<PyObject, PyObject, Ty, Ix>
where
Ty: EdgeType,
Ix: IndexType,
N: ToPyObject,
E: ToPyObject
{
let sccs = kosaraju_scc(&g);
let mut condensed: Graph<Vec<N>, E, Ty, Ix> = Graph::with_capacity(sccs.len(), g.edge_count());
let mut condensed: StableGraph<Vec<N>, E, Ty, Ix> = StableGraph::with_capacity(sccs.len(), g.edge_count());

// Build a map from old indices to new ones.
let mut node_map = vec![NodeIndex::end(); g.node_count()];
Expand All @@ -152,12 +154,27 @@ pub fn condensation(py: Python, graph: &digraph::PyDiGraph, sccs: Option<Vec<Vec
condensed.add_edge(source, target, edge.weight);
}
}
condensed.map(|_, w| w.to_object(*py), |_,w| w.to_object(*py))
}

#[pyfunction]
#[pyo3(text_signature = "(graph, /, sccs=None)", signature=(graph, sccs=None))]
pub fn condensation(py: Python, graph: &digraph::PyDiGraph, sccs: Option<Vec<Vec<usize>>>)
-> digraph::PyDiGraph {
let g = graph.graph.clone();

// TODO: Override sccs from arg
let condensed = if let Some(sccs) = sccs {
unimplemented!("")
} else {
condensation_inner(&py, g.into(), true)
};

// TODO: Fit for networkx
let result = condensed;

digraph::PyDiGraph {
graph: result.into(),
graph: result,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cycle_basis))?;
m.add_wrapped(wrap_pyfunction!(simple_cycles))?;
m.add_wrapped(wrap_pyfunction!(strongly_connected_components))?;
m.add_wrapped(wrap_pyfunction!(condensation))?;
m.add_wrapped(wrap_pyfunction!(digraph_dfs_edges))?;
m.add_wrapped(wrap_pyfunction!(graph_dfs_edges))?;
m.add_wrapped(wrap_pyfunction!(digraph_find_cycle))?;
Expand Down

0 comments on commit fa3e6ff

Please sign in to comment.