From 7ed0f1ff2e825891ebb9d91ff04c071a7002f790 Mon Sep 17 00:00:00 2001 From: Edwin Navarro Date: Tue, 18 Apr 2023 04:48:07 -0700 Subject: [PATCH 01/37] Add token swapper to rustworkx-core (#765) * Add token swapper * Finish add_token_edges setup rand * Finish main fn * Add swap fn * Add find_cycle and start trial_map * Finish good compile code * First testing * Rebuild swapper using struct and impl * Restructure around struct and token_swapper * Fix find_cycle and passing copies * First test success and cleanup * More testing and simplify di and sub_di graph * Finish tests * Format * Lint * Add python interface * Fix limit calc * Fix pyo3 signature * One more time * pyo3 fix * Fix signature 2 * Add python tests * Cleanup and reno * Fix conflict * Use connectivity::find_cycle * Review updates * Switch to map * Convert trial_map to parallel * Fix seed changes * Cleanup and comments * Name changes * Finish docs and add threshold * Fix stable graph removed nodes and error msgs --------- Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- Cargo.lock | 1 + docs/source/api.rst | 1 + .../added-token-swapper-bd168eeb5a31bd99.yaml | 6 + rustworkx-core/Cargo.toml | 2 + rustworkx-core/src/lib.rs | 3 + rustworkx-core/src/token_swapper.rs | 608 ++++++++++++++++++ src/lib.rs | 3 + src/token_swapper.rs | 69 ++ tests/rustworkx_tests/test_token_swapper.py | 118 ++++ 9 files changed, 811 insertions(+) create mode 100644 releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml create mode 100644 rustworkx-core/src/token_swapper.rs create mode 100644 src/token_swapper.rs create mode 100644 tests/rustworkx_tests/test_token_swapper.py diff --git a/Cargo.lock b/Cargo.lock index 8eee8253d..6e4267fc2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -562,6 +562,7 @@ dependencies = [ "petgraph", "priority-queue", "rand", + "rand_pcg", "rayon", "rayon-cond", ] diff --git a/docs/source/api.rst b/docs/source/api.rst index 487d543b8..44a7dab7a 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -370,6 +370,7 @@ typed API based on the data type. rustworkx.graph_complement rustworkx.graph_union rustworkx.graph_tensor_product + rustworkx.graph_token_swapper rustworkx.graph_cartesian_product rustworkx.graph_random_layout rustworkx.graph_bipartite_layout diff --git a/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml b/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml new file mode 100644 index 000000000..e26ec104e --- /dev/null +++ b/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new function, :func:`~.token_swapper()` which performs an + approximately optimal Token Swapping algorithm and supports partial + mappings (i.e. not-permutations) for graphs with missing tokens. diff --git a/rustworkx-core/Cargo.toml b/rustworkx-core/Cargo.toml index 7a3762a02..33c7e1200 100644 --- a/rustworkx-core/Cargo.toml +++ b/rustworkx-core/Cargo.toml @@ -14,6 +14,8 @@ keywords = ["graph"] ahash = "0.8.0" fixedbitset = "0.4.2" petgraph = "0.6.3" +rand = "0.8.5" +rand_pcg = "0.3.1" rayon = "1.6" num-traits = "0.2" priority-queue = "1.2" diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index 4da244f5c..ab54ad5dc 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -50,6 +50,7 @@ //! * [`connectivity`](./connectivity/index.html) //! * [`max_weight_matching`](./max_weight_matching/index.html) //! * [`shortest_path`](./shortest_path/index.html) +//! 
* [`token_swapper`](./token_swapper/index.html) //! * [`traversal`](./traversal/index.html) //! * [`generators`](./generators/index.html) //! @@ -82,6 +83,8 @@ pub mod traversal; pub mod dictmap; pub mod distancemap; mod min_scored; +/// Module for swapping tokens +pub mod token_swapper; pub mod utils; // re-export petgraph so there is a consistent version available to users and diff --git a/rustworkx-core/src/token_swapper.rs b/rustworkx-core/src/token_swapper.rs new file mode 100644 index 000000000..469236acc --- /dev/null +++ b/rustworkx-core/src/token_swapper.rs @@ -0,0 +1,608 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use rand::distributions::{Standard, Uniform}; +use rand::prelude::*; +use rand_pcg::Pcg64; +use std::hash::Hash; + +use hashbrown::HashMap; +use petgraph::stable_graph::{NodeIndex, StableGraph}; +use petgraph::visit::{ + EdgeCount, GraphBase, IntoEdges, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, + NodeIndexable, Visitable, +}; +use petgraph::Directed; +use petgraph::Direction::{Incoming, Outgoing}; +use rayon_cond::CondIterator; + +use crate::connectivity::find_cycle; +use crate::dictmap::*; +use crate::shortest_path::dijkstra; +use crate::traversal::dfs_edges; + +type Swap = (NodeIndex, NodeIndex); +type Edge = (NodeIndex, NodeIndex); + +struct TokenSwapper +where + G::NodeId: Eq + Hash, +{ + // The input graph + graph: G, + // The user-supplied mapping to use for swapping tokens + mapping: HashMap, + // Number of trials + trials: usize, + // Seed for random selection of a node for a trial + seed: Option, + // Threshold for how many nodes will trigger parallel iterator + parallel_threshold: usize, + // Map of NodeId to NodeIndex + node_map: HashMap, + // Map of NodeIndex to NodeId + rev_node_map: HashMap, +} + +impl TokenSwapper +where + G: NodeCount + + EdgeCount + + IntoEdges + + Visitable + + NodeIndexable + + IntoNeighborsDirected + + IntoNodeIdentifiers + + Send + + Sync, + G::NodeId: Hash + Eq + Send + Sync, +{ + fn new( + graph: G, + mapping: HashMap, + trials: Option, + seed: Option, + parallel_threshold: Option, + ) -> Self { + TokenSwapper { + graph, + mapping, + trials: trials.unwrap_or(4), + seed, + parallel_threshold: parallel_threshold.unwrap_or(50), + node_map: HashMap::with_capacity(graph.node_count()), + rev_node_map: HashMap::with_capacity(graph.node_count()), + } + } + + fn map(&mut self) -> Vec { + let num_nodes = self.graph.node_bound(); + let num_edges = self.graph.edge_count(); + + // Directed graph with nodes matching ``graph`` and + // edges for neighbors closer than nodes + let mut digraph = StableGraph::with_capacity(num_nodes, num_edges); + + // First fill the digraph with nodes. 
Then since it's a stable graph, + // must go through and remove nodes that were removed in original graph + for _ in 0..self.graph.node_bound() { + digraph.add_node(()); + } + let mut count: usize = 0; + for gnode in self.graph.node_identifiers() { + let gidx = self.graph.to_index(gnode); + if gidx != count { + for idx in count..gidx { + digraph.remove_node(NodeIndex::new(idx)); + } + count = gidx; + } + count += 1; + } + + // Create maps between NodeId and NodeIndex + for node in self.graph.node_identifiers() { + self.node_map + .insert(node, NodeIndex::new(self.graph.to_index(node))); + self.rev_node_map + .insert(NodeIndex::new(self.graph.to_index(node)), node); + } + // sub will become same as digraph but with no self edges in add_token_edges + let mut sub_digraph = digraph.clone(); + + // The mapping in HashMap form using NodeIndex + let mut tokens: HashMap = self + .mapping + .iter() + .map(|(k, v)| (self.node_map[k], self.node_map[v])) + .collect(); + + // todo_nodes are all the mapping entries where left != right + let todo_nodes: Vec = tokens + .iter() + .filter_map(|(node, dest)| if node != dest { Some(*node) } else { None }) + .collect(); + + // Add initial edges to the digraph/sub_digraph + for node in self.graph.node_identifiers() { + self.add_token_edges( + self.node_map[&node], + &mut digraph, + &mut sub_digraph, + &mut tokens, + ); + } + // First collect the self.trial number of random numbers + // into a Vec based on the given seed + let outer_rng: Pcg64 = match self.seed { + Some(rng_seed) => Pcg64::seed_from_u64(rng_seed), + None => Pcg64::from_entropy(), + }; + let trial_seeds_vec: Vec = + outer_rng.sample_iter(&Standard).take(self.trials).collect(); + + CondIterator::new( + trial_seeds_vec, + self.graph.node_count() >= self.parallel_threshold, + ) + .map(|trial_seed| { + self.trial_map( + digraph.clone(), + sub_digraph.clone(), + tokens.clone(), + todo_nodes.clone(), + trial_seed, + ) + }) + .min_by_key(|result| result.len()) + .unwrap() + } + + fn add_token_edges( + &self, + node: NodeIndex, + digraph: &mut StableGraph<(), (), Directed>, + sub_digraph: &mut StableGraph<(), (), Directed>, + tokens: &mut HashMap, + ) { + // Adds an edge to digraph if distance from the token to a neighbor is + // less than distance from token to node. sub_digraph is same except + // for self-edges. 
+ if !(tokens.contains_key(&node)) { + return; + } + if tokens[&node] == node { + digraph.update_edge(node, node, ()); + return; + } + let id_node = self.rev_node_map[&node]; + let id_token = self.rev_node_map[&tokens[&node]]; + for id_neighbor in self.graph.neighbors(id_node) { + let neighbor = self.node_map[&id_neighbor]; + let dist_neighbor: DictMap = dijkstra( + &self.graph, + id_neighbor, + Some(id_token), + |_| Ok::(1), + None, + ) + .unwrap(); + + let dist_node: DictMap = dijkstra( + &self.graph, + id_node, + Some(id_token), + |_| Ok::(1), + None, + ) + .unwrap(); + + if dist_neighbor[&id_token] < dist_node[&id_token] { + digraph.update_edge(node, neighbor, ()); + sub_digraph.update_edge(node, neighbor, ()); + } + } + } + + fn trial_map( + &self, + mut digraph: StableGraph<(), (), Directed>, + mut sub_digraph: StableGraph<(), (), Directed>, + mut tokens: HashMap, + mut todo_nodes: Vec, + trial_seed: u64, + ) -> Vec { + // Create a random trial list of swaps to move tokens to optimal positions + let mut steps = 0; + let mut swap_edges: Vec = vec![]; + let mut rng_seed: Pcg64 = Pcg64::seed_from_u64(trial_seed); + while !todo_nodes.is_empty() && steps <= 4 * digraph.node_count().pow(2) { + // Choose a random todo_node + let between = Uniform::new(0, todo_nodes.len()); + let random: usize = between.sample(&mut rng_seed); + let todo_node = todo_nodes[random]; + + // If there's a cycle in sub_digraph, add it to swap_edges and do swap + let cycle = find_cycle(&sub_digraph, Some(todo_node)); + if !cycle.is_empty() { + for edge in cycle[1..].iter().rev() { + swap_edges.push(*edge); + self.swap( + edge.0, + edge.1, + &mut digraph, + &mut sub_digraph, + &mut tokens, + &mut todo_nodes, + ); + } + steps += cycle.len() - 1; + // If there's no cycle, see if there's an edge target that matches a token key. + // If so, add to swap_edges and do swap + } else { + let mut found = false; + let sub2 = &sub_digraph.clone(); + for edge in dfs_edges(sub2, Some(todo_node)) { + let new_edge = (NodeIndex::new(edge.0), NodeIndex::new(edge.1)); + if !tokens.contains_key(&new_edge.1) { + swap_edges.push(new_edge); + self.swap( + new_edge.0, + new_edge.1, + &mut digraph, + &mut sub_digraph, + &mut tokens, + &mut todo_nodes, + ); + steps += 1; + found = true; + break; + } + } + // If none found, look for cycle in digraph which will result in + // an unhappy node. Look for a predecessor and add node and pred + // to swap_edges and do swap + if !found { + let cycle: Vec = find_cycle(&digraph, Some(todo_node)); + let unhappy_node = cycle[0].0; + let mut found = false; + let di2 = &mut digraph.clone(); + for predecessor in di2.neighbors_directed(unhappy_node, Incoming) { + if predecessor != unhappy_node { + swap_edges.push((unhappy_node, predecessor)); + self.swap( + unhappy_node, + predecessor, + &mut digraph, + &mut sub_digraph, + &mut tokens, + &mut todo_nodes, + ); + steps += 1; + found = true; + break; + } + } + assert!( + found, + "The token swap process has ended unexpectedly, this points to a bug in rustworkx, please open an issue." + ); + } + } + } + assert!( + todo_nodes.is_empty(), + "The output final swap map is incomplete, this points to a bug in rustworkx, please open an issue." 
+ ); + swap_edges + } + + fn swap( + &self, + node1: NodeIndex, + node2: NodeIndex, + digraph: &mut StableGraph<(), (), Directed>, + sub_digraph: &mut StableGraph<(), (), Directed>, + tokens: &mut HashMap, + todo_nodes: &mut Vec, + ) { + // Get token values for the 2 nodes and remove them + let token1 = tokens.remove(&node1); + let token2 = tokens.remove(&node2); + + // Swap the token edge values + if let Some(t2) = token2 { + tokens.insert(node1, t2); + } + if let Some(t1) = token1 { + tokens.insert(node2, t1); + } + // For each node, remove the (node, successor) from digraph and + // sub_digraph. Then add new token edges back in. + for node in [node1, node2] { + let edge_nodes: Vec<(NodeIndex, NodeIndex)> = digraph + .neighbors_directed(node, Outgoing) + .map(|successor| (node, successor)) + .collect(); + for (edge_node1, edge_node2) in edge_nodes { + let edge = digraph.find_edge(edge_node1, edge_node2).unwrap(); + digraph.remove_edge(edge); + } + let edge_nodes: Vec<(NodeIndex, NodeIndex)> = sub_digraph + .neighbors_directed(node, Outgoing) + .map(|successor| (node, successor)) + .collect(); + for (edge_node1, edge_node2) in edge_nodes { + let edge = sub_digraph.find_edge(edge_node1, edge_node2).unwrap(); + sub_digraph.remove_edge(edge); + } + self.add_token_edges(node, digraph, sub_digraph, tokens); + + // If a node is a token key and not equal to the value, add it to todo_nodes + if tokens.contains_key(&node) && tokens[&node] != node { + if !todo_nodes.contains(&node) { + todo_nodes.push(node); + } + // Otherwise if node is in todo_nodes, remove it + } else if todo_nodes.contains(&node) { + todo_nodes.swap_remove(todo_nodes.iter().position(|x| *x == node).unwrap()); + } + } + } +} + +/// Module to perform an approximately optimal Token Swapping algorithm. Supports partial +/// mappings (i.e. not-permutations) for graphs with missing tokens. +/// +/// Based on the paper: Approximation and Hardness for Token Swapping by Miltzow et al. (2016) +/// ArXiV: +/// +/// Arguments: +/// +/// * `graph` - The graph on which to perform the token swapping. +/// * `mapping` - A partial mapping to be implemented in swaps. +/// * `trials` - Optional number of trials. If None, defaults to 4. +/// * `seed` - Optional integer seed. If None, the internal rng will be initialized from system entropy. +/// * `parallel_threshold` - Optional integer for the number of nodes in the graph that will +/// trigger the use of parallel threads. If the number of nodes in the graph is less than this value +/// it will run in a single thread. The default value is 50. +/// +/// It returns a list of tuples representing the swaps to perform. +/// +/// This function is multithreaded and will launch a thread pool with threads equal to +/// the number of CPUs by default. You can tune the number of threads with +/// the ``RAYON_NUM_THREADS`` environment variable. For example, setting ``RAYON_NUM_THREADS=4`` +/// would limit the thread pool to 4 threads. 
+/// +/// # Example +/// ```rust +/// use hashbrown::HashMap; +/// use rustworkx_core::petgraph; +/// use rustworkx_core::token_swapper::token_swapper; +/// use rustworkx_core::petgraph::graph::NodeIndex; +/// +/// let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); +/// let mapping = HashMap::from([ +/// (NodeIndex::new(0), NodeIndex::new(0)), +/// (NodeIndex::new(1), NodeIndex::new(3)), +/// (NodeIndex::new(3), NodeIndex::new(1)), +/// (NodeIndex::new(2), NodeIndex::new(2)), +/// ]); +/// // Do the token swap +/// let output = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); +/// assert_eq!(3, output.len()); +/// +/// ``` + +pub fn token_swapper( + graph: G, + mapping: HashMap, + trials: Option, + seed: Option, + parallel_threshold: Option, +) -> Vec +where + G: NodeCount + + EdgeCount + + IntoEdges + + Visitable + + NodeIndexable + + IntoNeighborsDirected + + IntoNodeIdentifiers + + Send + + Sync, + G::NodeId: Hash + Eq + Send + Sync, +{ + let mut swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold); + swapper.map() +} + +#[cfg(test)] +mod test_token_swapper { + + use crate::petgraph; + use crate::token_swapper::token_swapper; + use hashbrown::HashMap; + use petgraph::graph::NodeIndex; + + fn do_swap(mapping: &mut HashMap, swaps: &Vec<(NodeIndex, NodeIndex)>) { + // Apply the swaps to the mapping to get final result + for (swap1, swap2) in swaps { + //Need to create temp nodes in case of partial mapping + let mut temp_node1: Option = None; + let mut temp_node2: Option = None; + if mapping.contains_key(swap1) { + temp_node1 = Some(mapping[swap1]); + mapping.remove(swap1); + } + if mapping.contains_key(swap2) { + temp_node2 = Some(mapping[swap2]); + mapping.remove(swap2); + } + if let Some(t1) = temp_node1 { + mapping.insert(*swap2, t1); + } + if let Some(t2) = temp_node2 { + mapping.insert(*swap1, t2); + } + } + } + + #[test] + fn test_simple_swap() { + // Simple arbitrary swap + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); + let mapping = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(0)), + (NodeIndex::new(1), NodeIndex::new(3)), + (NodeIndex::new(3), NodeIndex::new(1)), + (NodeIndex::new(2), NodeIndex::new(2)), + ]); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + assert_eq!(3, swaps.len()); + } + + #[test] + fn test_small_swap() { + // Reverse all small swap + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 6), + (6, 7), + ]); + let mut mapping = HashMap::with_capacity(8); + for i in 0..8 { + mapping.insert(NodeIndex::new(i), NodeIndex::new(7 - i)); + } + // Do the token swap + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(8); + for i in 0..8 { + expected.insert(NodeIndex::new(i), NodeIndex::new(i)); + } + assert_eq!(expected, new_map); + } + + #[test] + fn test_happy_swap_chain() { + // Reverse all happy swap chain > 2 + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[ + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (1, 2), + (1, 3), + (1, 4), + (2, 3), + (2, 4), + (3, 4), + (3, 6), + ]); + let mapping = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(4)), + (NodeIndex::new(1), NodeIndex::new(0)), + (NodeIndex::new(2), NodeIndex::new(3)), + (NodeIndex::new(3), NodeIndex::new(6)), + (NodeIndex::new(4), NodeIndex::new(2)), + (NodeIndex::new(6), 
NodeIndex::new(1)), + ]); + // Do the token swap + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(6); + for i in (0..5).chain(6..7) { + expected.insert(NodeIndex::new(i), NodeIndex::new(i)); + } + assert_eq!(expected, new_map); + } + + #[test] + fn test_partial_simple() { + // Simple partial swap + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); + let mapping = HashMap::from([(NodeIndex::new(0), NodeIndex::new(3))]); + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(1)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(4); + expected.insert(NodeIndex::new(3), NodeIndex::new(3)); + assert_eq!(expected, new_map); + } + + #[test] + fn test_partial_simple_remove_node() { + // Simple partial swap + let mut g = + petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4)]); + let mapping = HashMap::from([(NodeIndex::new(0), NodeIndex::new(3))]); + g.remove_node(NodeIndex::new(2)); + g.add_edge(NodeIndex::new(1), NodeIndex::new(3), ()); + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(1)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(4); + expected.insert(NodeIndex::new(3), NodeIndex::new(3)); + assert_eq!(expected, new_map); + } + + #[test] + fn test_partial_small() { + // Partial inverting on small path graph + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); + let mapping = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(3)), + (NodeIndex::new(1), NodeIndex::new(2)), + ]); + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + do_swap(&mut new_map, &swaps); + let expected = HashMap::from([ + (NodeIndex::new(2), NodeIndex::new(2)), + (NodeIndex::new(3), NodeIndex::new(3)), + ]); + assert_eq!(5, swaps.len()); + assert_eq!(expected, new_map); + } +} + +// TODO: Port this test when rustworkx-core adds random graphs + +// def test_large_partial_random(self) -> None: +// """Test a random (partial) mapping on a large randomly generated graph""" +// size = 100 +// # Note that graph may have "gaps" in the node counts, i.e. the numbering is noncontiguous. +// graph = rx.undirected_gnm_random_graph(size, size**2 // 10) +// for i in graph.node_indexes(): +// try: +// graph.remove_edge(i, i) # Remove self-loops. +// except rx.NoEdgeBetweenNodes: +// continue +// # Make sure the graph is connected by adding C_n +// graph.add_edges_from_no_data([(i, i + 1) for i in range(len(graph) - 1)]) +// swapper = ApproximateTokenSwapper(graph) # type: ApproximateTokenSwapper[int] + +// # Generate a randomized permutation. +// rand_perm = random.permutation(graph.nodes()) +// permutation = dict(zip(graph.nodes(), rand_perm)) +// mapping = dict(itertools.islice(permutation.items(), 0, size, 2)) # Drop every 2nd element. 
+ +// out = list(swapper.map(mapping, trials=40)) +// util.swap_permutation([out], mapping, allow_missing_keys=True) +// self.assertEqual({i: i for i in mapping.values()}, mapping) diff --git a/src/lib.rs b/src/lib.rs index 87d6eea3c..7f941e2e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,6 +31,7 @@ mod score; mod shortest_path; mod steiner_tree; mod tensor_product; +mod token_swapper; mod toposort; mod transitivity; mod traversal; @@ -52,6 +53,7 @@ use random_graph::*; use shortest_path::*; use steiner_tree::*; use tensor_product::*; +use token_swapper::*; use transitivity::*; use traversal::*; use tree::*; @@ -446,6 +448,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?; m.add_wrapped(wrap_pyfunction!(graph_transitivity))?; m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?; + m.add_wrapped(wrap_pyfunction!(graph_token_swapper))?; m.add_wrapped(wrap_pyfunction!(graph_core_number))?; m.add_wrapped(wrap_pyfunction!(digraph_core_number))?; m.add_wrapped(wrap_pyfunction!(graph_complement))?; diff --git a/src/token_swapper.rs b/src/token_swapper.rs new file mode 100644 index 000000000..0adf5b85b --- /dev/null +++ b/src/token_swapper.rs @@ -0,0 +1,69 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use crate::graph; +use crate::iterators::EdgeList; + +use hashbrown::HashMap; +use petgraph::graph::NodeIndex; +use pyo3::prelude::*; +use rustworkx_core::token_swapper; + +/// This module performs an approximately optimal Token Swapping algorithm +/// Supports partial mappings (i.e. not-permutations) for graphs with missing tokens. +/// +/// Based on the paper: Approximation and Hardness for Token Swapping by Miltzow et al. (2016) +/// ArXiV: https://arxiv.org/abs/1602.05150 +/// +/// The inputs are a partial ``mapping`` to be implemented in swaps, and the number of ``trials`` +/// to perform the mapping. It's minimized over the trials. +/// +/// It returns a list of tuples representing the swaps to perform. +/// +/// :param PyGraph graph: The input graph +/// :param dict[int: int] mapping: Map of (node, token) +/// :param int trials: The number of trials to run +/// :param int seed: The random seed to be used in producing random ints for selecting +/// which nodes to process next +/// :param int parallel_threshold: The number of nodes in the graph that will +/// trigger the use of parallel threads. If the number of nodes in the graph is less +/// than this value it will run in a single thread. The default value is 50. +/// +/// This function is multithreaded and will launch a thread pool with threads equal to +/// the number of CPUs by default. You can tune the number of threads with +/// the ``RAYON_NUM_THREADS`` environment variable. For example, setting ``RAYON_NUM_THREADS=4`` +/// would limit the thread pool to 4 threads. +/// +/// :returns: A list of tuples which are the swaps to be applied to the mapping to rearrange +/// the tokens. 
+/// :rtype: EdgeList +#[pyfunction] +#[pyo3(text_signature = "(graph, mapping, /, trials=None, seed=None, parallel_threshold=50)")] +pub fn graph_token_swapper( + graph: &graph::PyGraph, + mapping: HashMap, + trials: Option, + seed: Option, + parallel_threshold: Option, +) -> EdgeList { + let map: HashMap = mapping + .iter() + .map(|(s, t)| (NodeIndex::new(*s), NodeIndex::new(*t))) + .collect(); + let swaps = token_swapper::token_swapper(&graph.graph, map, trials, seed, parallel_threshold); + EdgeList { + edges: swaps + .into_iter() + .map(|(s, t)| (s.index(), t.index())) + .collect(), + } +} diff --git a/tests/rustworkx_tests/test_token_swapper.py b/tests/rustworkx_tests/test_token_swapper.py new file mode 100644 index 000000000..b5a207e32 --- /dev/null +++ b/tests/rustworkx_tests/test_token_swapper.py @@ -0,0 +1,118 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest +import itertools +import rustworkx as rx + +from numpy import random + + +def swap_permutation( + mapping, + swaps, +) -> None: + for (sw1, sw2) in list(swaps): + val1 = mapping.pop(sw1, None) + val2 = mapping.pop(sw2, None) + + if val1 is not None: + mapping[sw2] = val1 + if val2 is not None: + mapping[sw1] = val2 + + +class TestGeneral(unittest.TestCase): + """The test cases""" + + def setUp(self) -> None: + """Set up test cases.""" + super().setUp() + random.seed(0) + + def test_simple(self) -> None: + """Test a simple permutation on a path graph of size 4.""" + graph = rx.generators.path_graph(4) + permutation = {0: 0, 1: 3, 3: 1, 2: 2} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 1) + swap_permutation(permutation, swaps) + self.assertEqual(3, len(swaps)) + self.assertEqual({i: i for i in range(4)}, permutation) + + def test_small(self) -> None: + """Test an inverting permutation on a small path graph of size 8""" + graph = rx.generators.path_graph(8) + permutation = {i: 7 - i for i in range(8)} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 1) + swap_permutation(permutation, swaps) + self.assertEqual({i: i for i in range(8)}, permutation) + + def test_bug1(self) -> None: + """Tests for a bug that occured in happy swap chains of length >2.""" + graph = rx.PyGraph() + graph.extend_from_edge_list( + [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (3, 6)] + ) + permutation = {0: 4, 1: 0, 2: 3, 3: 6, 4: 2, 6: 1} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 1) + swap_permutation(permutation, swaps) + self.assertEqual({i: i for i in permutation}, permutation) + + def test_partial_simple(self) -> None: + """Test a partial mapping on a small graph.""" + graph = rx.generators.path_graph(4) + mapping = {0: 3} + swaps = rx.graph_token_swapper(graph, mapping, 4, 4, 10) + swap_permutation(mapping, swaps) + self.assertEqual(3, len(swaps)) + self.assertEqual({3: 3}, mapping) + + def test_partial_simple_remove_node(self) -> None: + """Test a partial mapping on a small graph with a node removed.""" + graph = rx.generators.path_graph(5) + 
graph.remove_node(2) + graph.add_edge(1, 3, None) + mapping = {0: 3} + swaps = rx.graph_token_swapper(graph, mapping, 4, 4, 10) + swap_permutation(mapping, swaps) + self.assertEqual(2, len(swaps)) + self.assertEqual({3: 3}, mapping) + + def test_partial_small(self) -> None: + """Test an partial inverting permutation on a small path graph of size 5""" + graph = rx.generators.path_graph(4) + permutation = {i: 3 - i for i in range(2)} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 10) + swap_permutation(permutation, swaps) + self.assertEqual(5, len(swaps)) + self.assertEqual({i: i for i in permutation.values()}, permutation) + + def test_large_partial_random(self) -> None: + """Test a random (partial) mapping on a large randomly generated graph""" + size = 100 + # Note that graph may have "gaps" in the node counts, i.e. the numbering is noncontiguous. + graph = rx.undirected_gnm_random_graph(size, size**2 // 10) + for i in graph.node_indexes(): + try: + graph.remove_edge(i, i) # Remove self-loops. + except rx.NoEdgeBetweenNodes: + continue + # Make sure the graph is connected by adding C_n + graph.add_edges_from_no_data([(i, i + 1) for i in range(len(graph) - 1)]) + + # Generate a randomized permutation. + rand_perm = random.permutation(graph.nodes()) + permutation = dict(zip(graph.nodes(), rand_perm)) + mapping = dict(itertools.islice(permutation.items(), 0, size, 2)) # Drop every 2nd element. + swaps = rx.graph_token_swapper(graph, permutation, 4, 4) + swap_permutation(mapping, swaps) + self.assertEqual({i: i for i in mapping.values()}, mapping) From 1bd168a605c2a0fa5e887c3e78bd37844b307709 Mon Sep 17 00:00:00 2001 From: Eric Arellano <14852634+Eric-Arellano@users.noreply.github.com> Date: Mon, 8 May 2023 14:32:11 -0600 Subject: [PATCH 02/37] Fix docs to work with Sphinx Theme 1.11 (#867) * Fix docs to work with Sphinx Theme 1.11 * Update docs/source/_templates/sidebar.html Co-authored-by: Matthew Treinish --- docs/source/_templates/layout.html | 300 +++++++++------------------ docs/source/_templates/page.html | 14 -- docs/source/_templates/sidebar.html | 99 +++++++++ docs/source/_templates/versions.html | 25 --- docs/source/conf.py | 12 +- docs/source/requirements.txt | 3 +- 6 files changed, 205 insertions(+), 248 deletions(-) delete mode 100644 docs/source/_templates/page.html create mode 100644 docs/source/_templates/sidebar.html delete mode 100644 docs/source/_templates/versions.html diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 102d6029e..515164542 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -1,48 +1,25 @@ -{# TEMPLATE VAR SETTINGS #} -{%- set url_root = pathto('', 1) %} -{%- if url_root == '#' %}{% set url_root = '' %}{% endif %} +{# Sphinx template variable setup #} {%- if not embedded and docstitle %} {%- set titlesuffix = " — "|safe + docstitle|e %} {%- else %} {%- set titlesuffix = "" %} {%- endif %} -{%- set lang_attr = 'en' if language == None else (language | replace('_', '-')) %} -{% import 'theme_variables.jinja' as theme_variables %} +{%- set lang_attr = 'en' if language == None else (language | replace('_', '-')) -%} - - + - - {{ metatags }} + - {% block htmltitle %} + {{ metatags }} + {%- block htmltitle %} {{ title|striptags|e }}{{ titlesuffix }} - {% endblock %} - - {# FAVICON #} - {% if favicon %} - - {% endif %} - {# CANONICAL URL #} - {% if theme_canonical_url %} - - {% endif %} - - {# CSS #} - - {# OPENSEARCH #} - {% if not embedded %} - {% if use_opensearch %} - 
- {% endif %} - - {% endif %} + {%- endblock %} + {%- if favicon_url %} + + {%- endif %} - - + {#- CSS #} {%- for css in css_files %} {%- if css|attr("rel") %} @@ -51,124 +28,108 @@ {%- endif %} {%- endfor %} {%- for cssfile in extra_css_files %} - - {%- endfor %} - - {%- block linktags %} - {%- if hasdoc('about') %} - - {%- endif %} - {%- if hasdoc('genindex') %} - - {%- endif %} - {%- if hasdoc('search') %} - - {%- endif %} - {%- if hasdoc('copyright') %} - - {%- endif %} - {%- if next %} - - {%- endif %} - {%- if prev %} - - {%- endif %} - {%- endblock %} - {%- block extrahead %} {% endblock %} - - {# Keep modernizr in head - http://modernizr.com/docs/#installing #} - + + {%- endfor -%} + + + {%- if analytics_enabled %} + + + + {%- endif -%} - - {% block extrabody %} {% endblock %} - {# SIDE NAV, TOGGLES ON MOBILE #} + - {% include "versions.html" %} + + {% include "languages.html" %} + - + + {% include "sidebar.html" %}
+ +
{% include "breadcrumbs.html" %}
- -
- Shortcuts -
+
-
+
{%- block content %} - {% if theme_style_external_links|tobool %} -
- {% if not embedded %} +{%- block footer %} {% endblock %} - {% if sphinx_version >= "1.8.0" %} - - {%- for scriptfile in script_files %} - {{ js_tag(scriptfile) }} - {%- endfor %} - {% else %} - - {%- for scriptfile in script_files %} - - {%- endfor %} - {% endif %} + + + {%- for scriptfile in script_files %} + {{ js_tag(scriptfile) }} + {%- endfor %} - {% endif %} + + + + - - - - + + -{%- block footer %} {% endblock %} - -
-
-
- - - -
-
-
-
- - -
-
-
- - -
- - - - - + diff --git a/docs/source/_templates/page.html b/docs/source/_templates/page.html deleted file mode 100644 index 429a7dedd..000000000 --- a/docs/source/_templates/page.html +++ /dev/null @@ -1,14 +0,0 @@ -{% extends "!page.html" %} - -{% block footer %} - -{% endblock %} \ No newline at end of file diff --git a/docs/source/_templates/sidebar.html b/docs/source/_templates/sidebar.html new file mode 100644 index 000000000..13c36ae8a --- /dev/null +++ b/docs/source/_templates/sidebar.html @@ -0,0 +1,99 @@ + + + +{% if expandable_sidebar %} + +{% endif %} + + + \ No newline at end of file diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html deleted file mode 100644 index 7e445164d..000000000 --- a/docs/source/_templates/versions.html +++ /dev/null @@ -1,25 +0,0 @@ -
- - {{ version_label }} - - -
-
-
Versions
-
Current Release
-
Development
-
Previous Releases
- {% for version in version_list %} -
{{ version }}
- {% endfor %} -
-
- -
diff --git a/docs/source/conf.py b/docs/source/conf.py index be034aca7..23faf0b94 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,8 +19,8 @@ # General configuration: -project = u'rustworkx' -copyright = u'2021, rustworkx Contributors' +project = 'rustworkx' +copyright = '2021, rustworkx Contributors' # The short X.Y version. @@ -40,6 +40,7 @@ 'sphinx.ext.intersphinx', 'sphinxemoji.sphinxemoji', 'sphinx_reredirects', + 'qiskit_sphinx_theme', ] html_static_path = ['_static'] templates_path = ['_templates'] @@ -86,30 +87,25 @@ .. note:: - This is the documnetation for the current state of the development branch + This is the documentation for the current state of the development branch of rustworkx. The documentation or APIs here can change prior to being released. """ # HTML Output Options - html_theme = 'qiskit_sphinx_theme' - html_theme_options = { 'logo_only': False, 'display_version': True, 'prev_next_buttons_location': 'bottom', 'style_external_links': True, } - htmlhelp_basename = 'rustworkx' # Latex options - latex_elements = {} - latex_documents = [ ('index', 'rustworkx.tex', u'rustworkx Documentation', u'rustworkx Contributors', 'manual'), diff --git a/docs/source/requirements.txt b/docs/source/requirements.txt index a960d2883..57816784a 100644 --- a/docs/source/requirements.txt +++ b/docs/source/requirements.txt @@ -1,11 +1,10 @@ m2r2 sphinx>=3.0.0 -sphinx_rtd_theme jupyter-sphinx pydot pillow>=4.2.1 reno>=3.4.0 -qiskit-sphinx-theme>=1.7 +qiskit-sphinx-theme~=1.11.1 matplotlib>=3.4 sphinx-reredirects sphinxemoji From 5a3f9b386b17728fd235982ddcf1e44d5f1b596d Mon Sep 17 00:00:00 2001 From: Eric Arellano <14852634+Eric-Arellano@users.noreply.github.com> Date: Mon, 8 May 2023 15:16:17 -0600 Subject: [PATCH 03/37] Turn off CI for forks (#868) Co-authored-by: Matthew Treinish --- .github/workflows/docs_dev.yml | 1 + .github/workflows/main.yml | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/.github/workflows/docs_dev.yml b/.github/workflows/docs_dev.yml index 9e1c93204..c55b0aab8 100644 --- a/.github/workflows/docs_dev.yml +++ b/.github/workflows/docs_dev.yml @@ -5,6 +5,7 @@ on: jobs: deploy: + if: github.repository_owner == 'Qiskit' runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4f18920f2..56db02926 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,6 +10,7 @@ concurrency: cancel-in-progress: true jobs: build_lint: + if: github.repository_owner == 'Qiskit' name: Build, rustfmt, and python lint runs-on: ubuntu-latest steps: @@ -49,6 +50,7 @@ jobs: name: rustworkx_core_docs path: target/doc/rustworkx_core tests: + if: github.repository_owner == 'Qiskit' needs: [build_lint] name: python${{ matrix.python-version }}-${{ matrix.platform.python-architecture }} ${{ matrix.platform.os }} ${{ matrix.msrv }} runs-on: ${{ matrix.platform.os }} @@ -87,6 +89,7 @@ jobs: - name: 'Run tests' run: tox -epy tests_stubs: + if: github.repository_owner == 'Qiskit' needs: [tests] name: python-stubs-${{ matrix.python-version }} runs-on: ubuntu-latest @@ -107,6 +110,7 @@ jobs: - name: 'Run rustworkx stub tests' run: tox -estubs tests_retworkx_compat: + if: github.repository_owner == 'Qiskit' needs: [build_lint] name: python${{ matrix.python-version }}-${{ matrix.platform.python-architecture }} ${{ matrix.platform.os }} ${{ matrix.msrv }} runs-on: ${{ matrix.platform.os }} @@ -147,6 +151,7 @@ jobs: cd tests stestr run -t ./retworkx_backwards_compat coverage: + if: 
github.repository_owner == 'Qiskit' needs: [tests] name: Coverage runs-on: ubuntu-latest @@ -189,6 +194,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: coveralls.info docs: + if: github.repository_owner == 'Qiskit' needs: [tests] name: Build Docs runs-on: ubuntu-latest From 8686896a9668c61e8b3468836a8552cde2a7dba5 Mon Sep 17 00:00:00 2001 From: Binh Vu Date: Tue, 9 May 2023 12:54:16 -0700 Subject: [PATCH 04/37] Fix pickle/deepcopy not preserve original edge indices (#589) * fix issue #585 that pickling graph & digraph do not preserve original edge index * fix clippy lints - collapsible_else_if * Simplify logic in __setstate__ * Add release note * Fix lint --------- Co-authored-by: Matthew Treinish --- ...-edge-indices-pickle-83fddf149441fa9f.yaml | 10 + src/digraph.rs | 231 +++++++++++++----- src/graph.rs | 223 ++++++++++++----- tests/rustworkx_tests/digraph/test_pickle.py | 41 ++++ tests/rustworkx_tests/graph/test_pickle.py | 41 ++++ 5 files changed, 418 insertions(+), 128 deletions(-) create mode 100644 releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml create mode 100644 tests/rustworkx_tests/digraph/test_pickle.py create mode 100644 tests/rustworkx_tests/graph/test_pickle.py diff --git a/releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml b/releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml new file mode 100644 index 000000000..238fe3cf6 --- /dev/null +++ b/releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml @@ -0,0 +1,10 @@ +--- +fixes: + - | + Fixed an issue when using ``copy.deepcopy()`` on :class:`~.PyDiGraph` and + :class:`~.PyGraph` objects when there were removed edges from the graph + object. Previously, if there were any holes in the edge indices caused by + the removal the output copy of the graph object would incorrectly have + flatten the indices. This has been corrected so that the edge indices are + recreated exactly after a ``deepcopy()``. 
+ Fixed `#585 `__ diff --git a/src/digraph.rs b/src/digraph.rs index 9339869fc..ba81cbae5 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -43,7 +43,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::visit::{ - GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable, + EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, Visitable, }; @@ -298,97 +298,196 @@ impl PyDiGraph { } fn __getstate__(&self, py: Python) -> PyResult { + let mut nodes: Vec = Vec::with_capacity(self.graph.node_count()); + let mut edges: Vec = Vec::with_capacity(self.graph.edge_bound()); + + // save nodes to a list along with its index + for node_idx in self.graph.node_indices() { + let node_data = self.graph.node_weight(node_idx).unwrap(); + nodes.push((node_idx.index(), node_data).to_object(py)); + } + + // edges are saved with none (deleted edges) instead of their index to save space + for i in 0..self.graph.edge_bound() { + let idx = EdgeIndex::new(i); + let edge = match self.graph.edge_weight(idx) { + Some(edge_w) => { + let endpoints = self.graph.edge_endpoints(idx).unwrap(); + (endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py) + } + None => py.None(), + }; + edges.push(edge); + } + let out_dict = PyDict::new(py); - let node_dict = PyDict::new(py); - let mut out_list: Vec = Vec::with_capacity(self.graph.edge_count()); - out_dict.set_item("nodes", node_dict)?; + let nodes_lst: PyObject = PyList::new(py, nodes).into(); + let edges_lst: PyObject = PyList::new(py, edges).into(); + out_dict.set_item("nodes", nodes_lst)?; + out_dict.set_item("edges", edges_lst)?; out_dict.set_item("nodes_removed", self.node_removed)?; out_dict.set_item("multigraph", self.multigraph)?; out_dict.set_item("attrs", self.attrs.clone_ref(py))?; out_dict.set_item("check_cycle", self.check_cycle)?; - let dir = petgraph::Direction::Incoming; - for node_index in self.graph.node_indices() { - let node_data = self.graph.node_weight(node_index).unwrap(); - node_dict.set_item(node_index.index(), node_data)?; - for edge in self.graph.edges_directed(node_index, dir) { - let edge_w = edge.weight(); - let triplet = (edge.source().index(), edge.target().index(), edge_w).to_object(py); - out_list.push(triplet); - } - } - let py_out_list: PyObject = PyList::new(py, out_list).into(); - out_dict.set_item("edges", py_out_list)?; Ok(out_dict.into()) } fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + let dict_state = state.downcast::(py)?; + let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::()?; + let edges_lst = dict_state.get_item("edges").unwrap().downcast::()?; self.graph = StablePyGraph::::new(); let dict_state = state.downcast::(py)?; - - let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::()?; - let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - let nodes_removed_raw = dict_state - .get_item("nodes_removed") - .unwrap() - .downcast::()?; - self.node_removed = nodes_removed_raw.extract()?; - let multigraph_raw = dict_state + self.multigraph = dict_state .get_item("multigraph") .unwrap() - .downcast::()?; - self.multigraph = multigraph_raw.extract()?; + .downcast::()? + .extract()?; + self.node_removed = dict_state + .get_item("nodes_removed") + .unwrap() + .downcast::()? 
+ .extract()?; let attrs = match dict_state.get_item("attrs") { Some(attr) => attr.into(), None => py.None(), }; self.attrs = attrs; - let check_cycle_raw = dict_state + self.check_cycle = dict_state .get_item("check_cycle") .unwrap() - .downcast::()?; - self.check_cycle = check_cycle_raw.extract()?; - let mut node_indices: Vec = Vec::new(); - for raw_index in nodes_dict.keys() { - let tmp_index = raw_index.downcast::()?; - node_indices.push(tmp_index.extract()?); - } - if node_indices.is_empty() { + .downcast::()? + .extract()?; + + // graph is empty, stop early + if nodes_lst.is_empty() { return Ok(()); } - let max_index: usize = *node_indices.iter().max().unwrap(); - if max_index + 1 != node_indices.len() { - self.node_removed = true; - } - let mut tmp_nodes: Vec = Vec::new(); - let mut node_count: usize = 0; - while max_index >= self.graph.node_bound() { - match nodes_dict.get_item(node_count) { - Some(raw_data) => { - self.graph.add_node(raw_data.into()); - } - None => { + + if !self.node_removed { + for item in nodes_lst.iter() { + let node_w = item + .downcast::() + .unwrap() + .get_item(1) + .unwrap() + .extract() + .unwrap(); + self.graph.add_node(node_w); + } + } else if nodes_lst.len() == 1 { + // graph has only one node, handle logic here to save one if in the loop later + let item = nodes_lst + .get_item(0) + .unwrap() + .downcast::() + .unwrap(); + let node_idx: usize = item.get_item(0).unwrap().extract().unwrap(); + let node_w = item.get_item(1).unwrap().extract().unwrap(); + + for _i in 0..node_idx { + self.graph.add_node(py.None()); + } + self.graph.add_node(node_w); + for i in 0..node_idx { + self.graph.remove_node(NodeIndex::new(i)); + } + } else { + let last_item = nodes_lst + .get_item(nodes_lst.len() - 1) + .unwrap() + .downcast::() + .unwrap(); + + // use a pointer to iter the node list + let mut pointer = 0; + let mut next_node_idx: usize = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + + // list of temporary nodes that will be removed later to re-create holes + let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); + let mut tmp_nodes: Vec = + Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); + + for i in 0..nodes_lst.len() + 1 { + if i < next_node_idx { + // node does not exist let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); + } else { + // add node to the graph, and update the next available node index + let item = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap(); + + let node_w = item.get_item(1).unwrap().extract().unwrap(); + self.graph.add_node(node_w); + pointer += 1; + if pointer < nodes_lst.len() { + next_node_idx = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + } } - }; - node_count += 1; - } - for tmp_node in tmp_nodes { - self.graph.remove_node(tmp_node); - } - for raw_edge in edges_list.iter() { - let edge = raw_edge.downcast::()?; - let raw_p_index = edge.get_item(0)?.downcast::()?; - let p_index: usize = raw_p_index.extract()?; - let raw_c_index = edge.get_item(1)?.downcast::()?; - let c_index: usize = raw_c_index.extract()?; - let edge_data = edge.get_item(2)?; - self.graph.add_edge( - NodeIndex::new(p_index), - NodeIndex::new(c_index), - edge_data.into(), - ); + } + // Remove any temporary nodes we added + for tmp_node in tmp_nodes { + 
self.graph.remove_node(tmp_node); + } } + + // to ensure O(1) on edge deletion, use a temporary node to store missing edges + let tmp_node = self.graph.add_node(py.None()); + + for item in edges_lst { + if item.is_none() { + // add a temporary edge that will be deleted later to re-create the hole + self.graph.add_edge(tmp_node, tmp_node, py.None()); + } else { + let triple = item.downcast::().unwrap(); + let edge_p: usize = triple + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_c: usize = triple + .get_item(1) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_w = triple.get_item(2).unwrap().extract().unwrap(); + self.graph + .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); + } + } + + // remove the temporary node will remove all deleted edges in bulk, + // the cost is equal to the number of edges + self.graph.remove_node(tmp_node); + Ok(()) } diff --git a/src/graph.rs b/src/graph.rs index 3b0f2c43a..7857cef9c 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -46,7 +46,7 @@ use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::visit::{ - GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable, + EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, }; /// A class for creating undirected graphs @@ -192,88 +192,187 @@ impl PyGraph { } fn __getstate__(&self, py: Python) -> PyResult { + let mut nodes: Vec = Vec::with_capacity(self.graph.node_count()); + let mut edges: Vec = Vec::with_capacity(self.graph.edge_bound()); + + // save nodes to a list along with its index + for node_idx in self.graph.node_indices() { + let node_data = self.graph.node_weight(node_idx).unwrap(); + nodes.push((node_idx.index(), node_data).to_object(py)); + } + + // edges are saved with none (deleted edges) instead of their index to save space + for i in 0..self.graph.edge_bound() { + let idx = EdgeIndex::new(i); + let edge = match self.graph.edge_weight(idx) { + Some(edge_w) => { + let endpoints = self.graph.edge_endpoints(idx).unwrap(); + (endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py) + } + None => py.None(), + }; + edges.push(edge); + } + let out_dict = PyDict::new(py); - let node_dict = PyDict::new(py); - let mut out_list: Vec = Vec::with_capacity(self.graph.edge_count()); - out_dict.set_item("nodes", node_dict)?; + let nodes_lst: PyObject = PyList::new(py, nodes).into(); + let edges_lst: PyObject = PyList::new(py, edges).into(); + out_dict.set_item("nodes", nodes_lst)?; + out_dict.set_item("edges", edges_lst)?; out_dict.set_item("nodes_removed", self.node_removed)?; out_dict.set_item("multigraph", self.multigraph)?; out_dict.set_item("attrs", self.attrs.clone_ref(py))?; - for node_index in self.graph.node_indices() { - let node_data = self.graph.node_weight(node_index).unwrap(); - node_dict.set_item(node_index.index(), node_data)?; - } - for edge in self.graph.edge_indices() { - let edge_w = self.graph.edge_weight(edge); - let endpoints = self.graph.edge_endpoints(edge).unwrap(); - - let triplet = (endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py); - out_list.push(triplet); - } - let py_out_list: PyObject = PyList::new(py, out_list).into(); - out_dict.set_item("edges", py_out_list)?; Ok(out_dict.into()) } fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - self.graph = StablePyGraph::::default(); let dict_state = state.downcast::(py)?; - let nodes_dict = 
dict_state.get_item("nodes").unwrap().downcast::()?; - let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - let nodes_removed_raw = dict_state - .get_item("nodes_removed") - .unwrap() - .downcast::()?; - self.node_removed = nodes_removed_raw.extract()?; - let multigraph_raw = dict_state + let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::()?; + let edges_lst = dict_state.get_item("edges").unwrap().downcast::()?; + + self.graph = StablePyGraph::::default(); + self.multigraph = dict_state .get_item("multigraph") .unwrap() - .downcast::()?; - self.multigraph = multigraph_raw.extract()?; - let attrs = match dict_state.get_item("attrs") { + .downcast::()? + .extract()?; + self.node_removed = dict_state + .get_item("nodes_removed") + .unwrap() + .downcast::()? + .extract()?; + self.attrs = match dict_state.get_item("attrs") { Some(attr) => attr.into(), None => py.None(), }; - self.attrs = attrs; - - let mut node_indices: Vec = Vec::new(); - for raw_index in nodes_dict.keys() { - let tmp_index = raw_index.downcast::()?; - node_indices.push(tmp_index.extract()?); - } - if node_indices.is_empty() { + // graph is empty, stop early + if nodes_lst.is_empty() { return Ok(()); } - let max_index: usize = *node_indices.iter().max().unwrap(); - let mut tmp_nodes: Vec = Vec::new(); - let mut node_count: usize = 0; - while max_index >= self.graph.node_bound() { - match nodes_dict.get_item(node_count) { - Some(raw_data) => { - self.graph.add_node(raw_data.into()); - } - None => { + + if !self.node_removed { + for item in nodes_lst.iter() { + let node_w = item + .downcast::() + .unwrap() + .get_item(1) + .unwrap() + .extract() + .unwrap(); + self.graph.add_node(node_w); + } + } else if nodes_lst.len() == 1 { + // graph has only one node, handle logic here to save one if in the loop later + let item = nodes_lst + .get_item(0) + .unwrap() + .downcast::() + .unwrap(); + let node_idx: usize = item.get_item(0).unwrap().extract().unwrap(); + let node_w = item.get_item(1).unwrap().extract().unwrap(); + + for _i in 0..node_idx { + self.graph.add_node(py.None()); + } + self.graph.add_node(node_w); + for i in 0..node_idx { + self.graph.remove_node(NodeIndex::new(i)); + } + } else { + let last_item = nodes_lst + .get_item(nodes_lst.len() - 1) + .unwrap() + .downcast::() + .unwrap(); + + // use a pointer to iter the node list + let mut pointer = 0; + let mut next_node_idx: usize = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + + // list of temporary nodes that will be removed later to re-create holes + let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); + let mut tmp_nodes: Vec = + Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); + + for i in 0..nodes_lst.len() + 1 { + if i < next_node_idx { + // node does not exist let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); + } else { + // add node to the graph, and update the next available node index + let item = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap(); + + let node_w = item.get_item(1).unwrap().extract().unwrap(); + self.graph.add_node(node_w); + pointer += 1; + if pointer < nodes_lst.len() { + next_node_idx = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + } } - }; - node_count += 1; - } - for tmp_node in tmp_nodes { - 
self.graph.remove_node(tmp_node); + } + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); + } } - for raw_edge in edges_list.iter() { - let edge = raw_edge.downcast::()?; - let raw_p_index = edge.get_item(0)?.downcast::()?; - let parent: usize = raw_p_index.extract()?; - let p_index = NodeIndex::new(parent); - let raw_c_index = edge.get_item(1)?.downcast::()?; - let child: usize = raw_c_index.extract()?; - let c_index = NodeIndex::new(child); - let edge_data = edge.get_item(2)?; - - self.graph.add_edge(p_index, c_index, edge_data.into()); + + // to ensure O(1) on edge deletion, use a temporary node to store missing edges + let tmp_node = self.graph.add_node(py.None()); + + for item in edges_lst { + if item.is_none() { + // add a temporary edge that will be deleted later to re-create the hole + self.graph.add_edge(tmp_node, tmp_node, py.None()); + } else { + let triple = item.downcast::().unwrap(); + let edge_p: usize = triple + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_c: usize = triple + .get_item(1) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_w = triple.get_item(2).unwrap().extract().unwrap(); + self.graph + .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); + } } + + // remove the temporary node will remove all deleted edges in bulk, + // the cost is equal to the number of edges + self.graph.remove_node(tmp_node); + Ok(()) } diff --git a/tests/rustworkx_tests/digraph/test_pickle.py b/tests/rustworkx_tests/digraph/test_pickle.py new file mode 100644 index 000000000..306fd119c --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_pickle.py @@ -0,0 +1,41 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import pickle +import unittest + +import rustworkx as rx + + +class TestPickleDiGraph(unittest.TestCase): + def test_noweight_graph(self): + g = rx.PyDAG() + for i in range(4): + g.add_node(None) + g.add_edges_from_no_data([(0, 1), (1, 2), (3, 0), (3, 1)]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual([None, None, None], gprime.nodes()) + self.assertEqual({1: (1, 2, None), 3: (3, 1, None)}, dict(gprime.edge_index_map())) + + def test_weight_graph(self): + g = rx.PyDAG() + g.add_nodes_from(["A", "B", "C", "D"]) + g.add_edges_from([(0, 1, "A -> B"), (1, 2, "B -> C"), (3, 0, "D -> A"), (3, 1, "D -> B")]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual(["B", "C", "D"], gprime.nodes()) + self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map())) diff --git a/tests/rustworkx_tests/graph/test_pickle.py b/tests/rustworkx_tests/graph/test_pickle.py new file mode 100644 index 000000000..44220f113 --- /dev/null +++ b/tests/rustworkx_tests/graph/test_pickle.py @@ -0,0 +1,41 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import pickle +import unittest + +import rustworkx as rx + + +class TestPickleGraph(unittest.TestCase): + def test_noweight_graph(self): + g = rx.PyGraph() + for i in range(4): + g.add_node(None) + g.add_edges_from_no_data([(0, 1), (1, 2), (3, 0), (3, 1)]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual([None, None, None], gprime.nodes()) + self.assertEqual({1: (1, 2, None), 3: (3, 1, None)}, dict(gprime.edge_index_map())) + + def test_weight_graph(self): + g = rx.PyGraph() + g.add_nodes_from(["A", "B", "C", "D"]) + g.add_edges_from([(0, 1, "A -> B"), (1, 2, "B -> C"), (3, 0, "D -> A"), (3, 1, "D -> B")]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual(["B", "C", "D"], gprime.nodes()) + self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map())) From c17eea558841d2b473226917cd3da90e24a616ee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 May 2023 21:59:17 +0000 Subject: [PATCH 05/37] Bump serde from 1.0.160 to 1.0.162 (#863) Bumps [serde](https://github.com/serde-rs/serde) from 1.0.160 to 1.0.162. - [Release notes](https://github.com/serde-rs/serde/releases) - [Commits](https://github.com/serde-rs/serde/compare/v1.0.160...1.0.162) --- updated-dependencies: - dependency-name: serde dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Matthew Treinish --- Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6e4267fc2..b9dbbb3cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -581,18 +581,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.160" +version = "1.0.162" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" +checksum = "71b2f6e1ab5c2b98c05f0f35b236b22e8df7ead6ffbf51d7808da7f8817e7ab6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.160" +version = "1.0.162" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" +checksum = "a2a0814352fd64b58489904a44ea8d90cb1a91dcb6b4f5ebabc32c8318e93cb6" dependencies = [ "proc-macro2", "quote", From a16c18d63635961075037c614559f3b9214434f9 Mon Sep 17 00:00:00 2001 From: matanco64 <38103422+matanco64@users.noreply.github.com> Date: Wed, 10 May 2023 22:13:48 +0300 Subject: [PATCH 06/37] Add reverse inplace function for digraph (#853) * added a reverse_inplace function in digraph, the function reverses the direction of the edges in the digraph implemented by switching the indices of the nodes in an edge. * added python tests for the reverse_inplace function. testing a simple case and a case for a large graph. * ran rust fmt and clippy, also added more detailed documentation * rename reverse_inplace to reverse * change excepts to unwraps (If this fails is because of PyO3. It panics and there is not much point in printing a message) * added tests for empty graph and graph with node removed in the middle * added interface signature for IDEs * ran cargo fmt * Fix doc syntax --------- Co-authored-by: Matthew Treinish --- rustworkx/digraph.pyi | 1 + src/digraph.rs | 33 +++++++++++++++++ tests/rustworkx_tests/digraph/test_edges.py | 41 +++++++++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/rustworkx/digraph.pyi b/rustworkx/digraph.pyi index 56b977d98..13735b5fc 100644 --- a/rustworkx/digraph.pyi +++ b/rustworkx/digraph.pyi @@ -166,6 +166,7 @@ class PyDiGraph(Generic[S, T]): deliminator: Optional[str] = ..., weight_fn: Optional[Callable[[T], str]] = ..., ) -> None: ... + def reverse(self) -> None: ... def __delitem__(self, idx: int, /) -> None: ... def __getitem__(self, idx: int, /) -> S: ... def __getstate__(self) -> Any: ... diff --git a/src/digraph.rs b/src/digraph.rs index ba81cbae5..c4d7518d0 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2819,6 +2819,39 @@ impl PyDiGraph { self.clone() } + /// Reverse the direction of all edges in the graph, in place. + /// + /// This method modifies the graph instance to reverse the direction of all edges. + /// It does so by iterating over all edges in the graph and removing each edge, + /// then adding a new edge in the opposite direction with the same weight. 
+ /// + /// For Example:: + /// + /// import rustworkx as rx + /// + /// graph = rx.PyDiGraph() + /// + /// # Generate a path directed path graph with weights + /// graph.extend_from_weighted_edge_list([ + /// (0, 1, 3), + /// (1, 2, 5), + /// (2, 3, 2), + /// ]) + /// # Reverse edges + /// graph.reverse() + /// + /// assert graph.weighted_edge_list() == [(3, 2, 2), (2, 1, 5), (1, 0, 3)]; + #[pyo3(text_signature = "(self)")] + pub fn reverse(&mut self, py: Python) { + let indices = self.graph.edge_indices().collect::>(); + for idx in indices { + let (source_node, dest_node) = self.graph.edge_endpoints(idx).unwrap(); + let weight = self.graph.edge_weight(idx).unwrap().clone_ref(py); + self.graph.remove_edge(idx); + self.graph.add_edge(dest_node, source_node, weight); + } + } + /// Return the number of nodes in the graph fn __len__(&self) -> PyResult { Ok(self.graph.node_count()) diff --git a/tests/rustworkx_tests/digraph/test_edges.py b/tests/rustworkx_tests/digraph/test_edges.py index 54a448bfe..2d9f56ae5 100644 --- a/tests/rustworkx_tests/digraph/test_edges.py +++ b/tests/rustworkx_tests/digraph/test_edges.py @@ -962,3 +962,44 @@ def test_extend_from_weighted_edge_list(self): graph.extend_from_weighted_edge_list(edge_list) self.assertEqual(len(graph), 4) self.assertEqual(["a", "b", "c", "d", "e"], graph.edges()) + + def test_reverse_graph(self): + graph = rustworkx.PyDiGraph() + graph.add_nodes_from([i for i in range(4)]) + edge_list = [ + (0, 1, "a"), + (1, 2, "b"), + (0, 2, "c"), + (2, 3, "d"), + (0, 3, "e"), + ] + graph.add_edges_from(edge_list) + graph.reverse() + self.assertEqual([(1, 0), (2, 1), (2, 0), (3, 2), (3, 0)], graph.edge_list()) + + def test_reverse_large_graph(self): + LARGE_AMOUNT_OF_NODES = 10000000 + + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(range(LARGE_AMOUNT_OF_NODES)) + edge_list = list(zip(range(LARGE_AMOUNT_OF_NODES), range(1, LARGE_AMOUNT_OF_NODES))) + weighted_edge_list = [(s, d, "a") for s, d in edge_list] + graph.add_edges_from(weighted_edge_list) + graph.reverse() + reversed_edge_list = [(d, s) for s, d in edge_list] + self.assertEqual(reversed_edge_list, graph.edge_list()) + + def test_reverse_empty_graph(self): + graph = rustworkx.PyDiGraph() + edges_before = graph.edge_list() + graph.reverse() + self.assertEqual(graph.edge_list(), edges_before) + + def test_removed_middle_node_reverse(self): + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(list(range(5))) + edge_list = [(0, 1), (2, 1), (1, 3), (3, 4), (4, 0)] + graph.extend_from_edge_list(edge_list) + graph.remove_node(1) + graph.reverse() + self.assertEqual(graph.edge_list(), [(4, 3), (0, 4)]) From 70e49653a4d7ff0dae62d3f16d5da343ba2e77da Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 May 2023 15:47:09 +0000 Subject: [PATCH 07/37] Bump serde from 1.0.162 to 1.0.163 (#869) Bumps [serde](https://github.com/serde-rs/serde) from 1.0.162 to 1.0.163. - [Release notes](https://github.com/serde-rs/serde/releases) - [Commits](https://github.com/serde-rs/serde/compare/v1.0.162...v1.0.163) --- updated-dependencies: - dependency-name: serde dependency-type: direct:production update-type: version-update:semver-patch ... 
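Beyond the docstring example above, the new tests also cover graphs with holes in the node index space; a small Python sketch of that case, taken directly from the tests::

    import rustworkx as rx

    graph = rx.PyDiGraph()
    graph.add_nodes_from(list(range(5)))
    graph.extend_from_edge_list([(0, 1), (2, 1), (1, 3), (3, 4), (4, 0)])
    graph.remove_node(1)
    graph.reverse()
    # only the surviving edges are reversed
    assert graph.edge_list() == [(4, 3), (0, 4)]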
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b9dbbb3cf..a3c081365 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -581,18 +581,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.162" +version = "1.0.163" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71b2f6e1ab5c2b98c05f0f35b236b22e8df7ead6ffbf51d7808da7f8817e7ab6" +checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.162" +version = "1.0.163" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2a0814352fd64b58489904a44ea8d90cb1a91dcb6b4f5ebabc32c8318e93cb6" +checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" dependencies = [ "proc-macro2", "quote", From c917bca398ab892636582831d7a3390b64a9fbb3 Mon Sep 17 00:00:00 2001 From: Eric Arellano <14852634+Eric-Arellano@users.noreply.github.com> Date: Fri, 12 May 2023 10:41:12 -0600 Subject: [PATCH 08/37] Remove unncessary CSS files (#871) --- docs/source/_static/custom.css | 27 ----- docs/source/_static/gallery.css | 195 -------------------------------- docs/source/_static/style.css | 12 -- docs/source/conf.py | 2 - 4 files changed, 236 deletions(-) delete mode 100644 docs/source/_static/custom.css delete mode 100644 docs/source/_static/gallery.css delete mode 100644 docs/source/_static/style.css diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css deleted file mode 100644 index 3f981b658..000000000 --- a/docs/source/_static/custom.css +++ /dev/null @@ -1,27 +0,0 @@ -.toggle .header { - display: block; - clear: both; - background-color: #785EF0; - color: #f9f9f9; - height: 40px; - padding-top: 10px; - padding-left: 5px; - margin-bottom: 20px; -} - -.toggle .header:before { - float: left; - content: "▶ "; - font-size: 20px; - -} - -.toggle .header.open:before { - float: left; - content: "▼ "; - font-size: 20px; -} - -.toggle{ - background: #FBFBFB; -} diff --git a/docs/source/_static/gallery.css b/docs/source/_static/gallery.css deleted file mode 100644 index e5f79a8c5..000000000 --- a/docs/source/_static/gallery.css +++ /dev/null @@ -1,195 +0,0 @@ -/* -Sphinx-Gallery has compatible CSS to fix default sphinx themes -Tested for Sphinx 1.3.1 for all themes: default, alabaster, sphinxdoc, -scrolls, agogo, traditional, nature, haiku, pyramid -Tested for Read the Docs theme 0.1.7 */ -.sphx-glr-thumbcontainer { - background: #fff; - border: solid #fff 1px; - -moz-border-radius: 5px; - -webkit-border-radius: 5px; - border-radius: 5px; - box-shadow: none; - float: left; - margin: 5px; - min-height: 230px; - padding-top: 5px; - position: relative; -} -.sphx-glr-thumbcontainer:hover { - border: solid #6200EE 1px; - box-shadow: 0 0 15px rgba(163, 142, 202, 0.5); -} -.sphx-glr-thumbcontainer a.internal { - bottom: 0; - display: block; - left: 0; - padding: 150px 10px 0; - position: absolute; - right: 0; - top: 0; -} -/* Next one is to avoid Sphinx traditional theme to cover all the -thumbnail with its default link Background color */ -.sphx-glr-thumbcontainer a.internal:hover { - background-color: transparent; -} - -.sphx-glr-thumbcontainer p { - margin: 0 0 .1em 0; -} -.sphx-glr-thumbcontainer .figure { - margin: 10px; - width: 160px; -} 
-.sphx-glr-thumbcontainer img { - display: inline; - max-height: 112px; - max-width: 160px; -} -.sphx-glr-thumbcontainer[tooltip]:hover:after { - background: rgba(0, 0, 0, 0.8); - -webkit-border-radius: 5px; - -moz-border-radius: 5px; - border-radius: 5px; - color: #fff; - content: attr(tooltip); - left: 95%; - padding: 5px 15px; - position: absolute; - z-index: 98; - width: 220px; - bottom: 52%; -} -.sphx-glr-thumbcontainer[tooltip]:hover:before { - border: solid; - border-color: #333 transparent; - border-width: 18px 0 0 20px; - bottom: 58%; - content: ''; - left: 85%; - position: absolute; - z-index: 99; -} - -.sphx-glr-script-out { - color: #888; - margin: 0; -} -p.sphx-glr-script-out { - padding-top: 0.7em; -} -.sphx-glr-script-out .highlight { - background-color: transparent; - margin-left: 2.5em; - margin-top: -2.1em; -} -.sphx-glr-script-out .highlight pre { - background-color: #fafae2; - border: 0; - max-height: 30em; - overflow: auto; - padding-left: 1ex; - margin: 0px; - word-break: break-word; -} -.sphx-glr-script-out + p { - margin-top: 1.8em; -} -blockquote.sphx-glr-script-out { - margin-left: 0pt; -} -.sphx-glr-script-out.highlight-pytb .highlight pre { - color: #000; - background-color: #ffe4e4; - border: 1px solid #f66; - margin-top: 10px; - padding: 7px; -} - -div.sphx-glr-footer { - text-align: center; -} - -div.sphx-glr-download { - margin: 1em auto; - vertical-align: middle; -} - -div.sphx-glr-download a { - background-color: #ffc; - background-image: linear-gradient(to bottom, #FFC, #d5d57e); - border-radius: 4px; - border: 1px solid #c2c22d; - color: #000; - display: inline-block; - font-weight: bold; - padding: 1ex; - text-align: center; -} - -div.sphx-glr-download code.download { - display: inline-block; - white-space: normal; - word-break: normal; - overflow-wrap: break-word; - /* border and background are given by the enclosing 'a' */ - border: none; - background: none; -} - -div.sphx-glr-download a:hover { - box-shadow: inset 0 1px 0 rgba(255,255,255,.1), 0 1px 5px rgba(0,0,0,.25); - text-decoration: none; - background-image: none; - background-color: #d5d57e; -} - -.sphx-glr-example-title > :target::before { - display: block; - content: ""; - margin-top: -50px; - height: 50px; - visibility: hidden; -} - -ul.sphx-glr-horizontal { - list-style: none; - padding: 0; -} -ul.sphx-glr-horizontal li { - display: inline; -} -ul.sphx-glr-horizontal img { - height: auto !important; -} - -.sphx-glr-single-img { - margin: auto; - display: block; - max-width: 100%; -} - -.sphx-glr-multi-img { - max-width: 42%; - height: auto; -} - -p.sphx-glr-signature a.reference.external { - -moz-border-radius: 5px; - -webkit-border-radius: 5px; - border-radius: 5px; - padding: 3px; - font-size: 75%; - text-align: right; - margin-left: auto; - display: table; -} - -.sphx-glr-clear{ - clear: both; -} - -a.sphx-glr-backref-instance { - text-decoration: none; -} diff --git a/docs/source/_static/style.css b/docs/source/_static/style.css deleted file mode 100644 index 84980a3a4..000000000 --- a/docs/source/_static/style.css +++ /dev/null @@ -1,12 +0,0 @@ -.wy-nav-content { - max-width: 90% !important; -} - -.wy-side-scroll { - background:#8c8c8c; -} - -.pre -{ -color:#BE8184; -} diff --git a/docs/source/conf.py b/docs/source/conf.py index 23faf0b94..61e13c1b8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -42,9 +42,7 @@ 'sphinx_reredirects', 'qiskit_sphinx_theme', ] -html_static_path = ['_static'] templates_path = ['_templates'] -html_css_files = ['style.css', 'custom.css'] 
pygments_style = 'colorful' From e538391a852760fc5a6d40eacdb47f90a1516dbf Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Fri, 12 May 2023 18:00:42 -0400 Subject: [PATCH 09/37] Check if nodes exist on `add_edge` methods (#862) * Add tests from the example * Fix bug * Fix tests * Add release note * Update release note * Apply suggestions from code review Co-authored-by: Matthew Treinish * Fix docs to work with Sphinx Theme 1.11 (#867) * Fix docs to work with Sphinx Theme 1.11 * Update docs/source/_templates/sidebar.html Co-authored-by: Matthew Treinish * Turn off CI for forks (#868) Co-authored-by: Matthew Treinish * Fix pickle/deepcopy not preserve original edge indices (#589) * fix issue #585 that pickling graph & digraph do not preserve original edge index * fix clippy lints - collapsible_else_if * Simplify logic in __setstate__ * Add release note * Fix lint --------- Co-authored-by: Matthew Treinish * Bump serde from 1.0.160 to 1.0.162 (#863) Bumps [serde](https://github.com/serde-rs/serde) from 1.0.160 to 1.0.162. - [Release notes](https://github.com/serde-rs/serde/releases) - [Commits](https://github.com/serde-rs/serde/compare/v1.0.160...1.0.162) --- updated-dependencies: - dependency-name: serde dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Matthew Treinish * Add reverse inplace function for digraph (#853) * added a reverse_inplace function in digraph, the function reverses the direction of the edges in the digraph implemented by switching the indices of the nodes in an edge. * added python tests for the reverse_inplace function. testing a simple case and a case for a large graph. * ran rust fmt and clippy, also added more detailed documentation * rename reverse_inplace to reverse * change excepts to unwraps (If this fails is because of PyO3. It panics and there is not much point in printing a message) * added tests for empty graph and graph with node removed in the middle * added interface signature for IDEs * ran cargo fmt * Fix doc syntax --------- Co-authored-by: Matthew Treinish * Bump serde from 1.0.162 to 1.0.163 (#869) Bumps [serde](https://github.com/serde-rs/serde) from 1.0.162 to 1.0.163. - [Release notes](https://github.com/serde-rs/serde/releases) - [Commits](https://github.com/serde-rs/serde/compare/v1.0.162...v1.0.163) --- updated-dependencies: - dependency-name: serde dependency-type: direct:production update-type: version-update:semver-patch ... 
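A minimal sketch of the behavior this patch introduces (mirroring the new tests added below): adding an edge whose endpoints are not in the graph now raises ``IndexError`` instead of panicking::

    import rustworkx as rx

    graph = rx.PyDiGraph()
    try:
        graph.add_edge(2, 3, None)  # neither node 2 nor node 3 exists yet
    except IndexError as err:
        print(err)  # "One of the endpoints of the edge does not exist in graph"

    graph.add_nodes_from([None, None, None, None])
    graph.add_edge(2, 3, None)  # succeeds once both endpoints exist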
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Extend fixes to add_edges_from and add_edges_from_no_data * Lower amount of nodes in test --------- Signed-off-by: dependabot[bot] Co-authored-by: Matthew Treinish Co-authored-by: Eric Arellano <14852634+Eric-Arellano@users.noreply.github.com> Co-authored-by: Binh Vu Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: matanco64 <38103422+matanco64@users.noreply.github.com> --- ...le-non-existent-edge-15d70cfe60c89ac2.yaml | 6 ++++ src/connectivity/mod.rs | 2 +- src/digraph.rs | 13 ++++---- src/graph.rs | 30 +++++++++++-------- src/tree.rs | 2 +- tests/rustworkx_tests/digraph/test_edges.py | 17 ++++++++++- tests/rustworkx_tests/graph/test_edges.py | 15 ++++++++++ 7 files changed, 63 insertions(+), 22 deletions(-) create mode 100644 releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml diff --git a/releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml b/releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml new file mode 100644 index 000000000..8f8e2acd2 --- /dev/null +++ b/releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml @@ -0,0 +1,6 @@ +--- +fixes: + - | + :meth:`rustworkx.PyGraph.add_edge` and :meth:`rustworkx.PyDiGraph.add_edge` and now raises an + ``IndexError`` when one of the nodes does not exist in the graph. Previously, it caused the Python + interpreter to exit with a ``PanicException`` diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index 005ad08af..e17d30733 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -398,7 +398,7 @@ pub fn graph_complement(py: Python, graph: &graph::PyGraph) -> PyResult PyResult { let p_index = NodeIndex::new(parent); let c_index = NodeIndex::new(child); + if !self.graph.contains_node(p_index) || !self.graph.contains_node(c_index) { + return Err(PyIndexError::new_err( + "One of the endpoints of the edge does not exist in graph", + )); + } let out_index = self._add_edge(p_index, c_index, edge)?; Ok(out_index) } @@ -1103,9 +1108,7 @@ impl PyDiGraph { ) -> PyResult> { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - let edge = self._add_edge(p_index, c_index, obj.2)?; + let edge = self.add_edge(obj.0, obj.1, obj.2)?; out_list.push(edge); } Ok(out_list) @@ -1129,9 +1132,7 @@ impl PyDiGraph { ) -> PyResult> { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - let edge = self._add_edge(p_index, c_index, py.None())?; + let edge = self.add_edge(obj.0, obj.1, py.None())?; out_list.push(edge); } Ok(out_list) diff --git a/src/graph.rs b/src/graph.rs index 7857cef9c..d57fbb321 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -847,10 +847,15 @@ impl PyGraph { /// of an existing edge with ``multigraph=False``) edge. 
/// :rtype: int #[pyo3(text_signature = "(self, node_a, node_b, edge, /)")] - pub fn add_edge(&mut self, node_a: usize, node_b: usize, edge: PyObject) -> usize { + pub fn add_edge(&mut self, node_a: usize, node_b: usize, edge: PyObject) -> PyResult { let p_index = NodeIndex::new(node_a); let c_index = NodeIndex::new(node_b); - self._add_edge(p_index, c_index, edge) + if !self.graph.contains_node(p_index) || !self.graph.contains_node(c_index) { + return Err(PyIndexError::new_err( + "One of the endpoints of the edge does not exist in graph", + )); + } + Ok(self._add_edge(p_index, c_index, edge)) } /// Add new edges to the graph. @@ -869,14 +874,15 @@ impl PyGraph { /// :returns: A list of int indices of the newly created edges /// :rtype: list #[pyo3(text_signature = "(self, obj_list, /)")] - pub fn add_edges_from(&mut self, obj_list: Vec<(usize, usize, PyObject)>) -> EdgeIndices { + pub fn add_edges_from( + &mut self, + obj_list: Vec<(usize, usize, PyObject)>, + ) -> PyResult { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - out_list.push(self._add_edge(p_index, c_index, obj.2)); + out_list.push(self.add_edge(obj.0, obj.1, obj.2)?); } - EdgeIndices { edges: out_list } + Ok(EdgeIndices { edges: out_list }) } /// Add new edges to the graph without python data. @@ -898,14 +904,12 @@ impl PyGraph { &mut self, py: Python, obj_list: Vec<(usize, usize)>, - ) -> EdgeIndices { + ) -> PyResult { let mut out_list: Vec = Vec::with_capacity(obj_list.len()); for obj in obj_list { - let p_index = NodeIndex::new(obj.0); - let c_index = NodeIndex::new(obj.1); - out_list.push(self._add_edge(p_index, c_index, py.None())); + out_list.push(self.add_edge(obj.0, obj.1, py.None())?); } - EdgeIndices { edges: out_list } + Ok(EdgeIndices { edges: out_list }) } /// Extend graph from an edge list @@ -1703,7 +1707,7 @@ impl PyGraph { } for (source, weight) in edges { - self.add_edge(source.index(), node_index.index(), weight); + self.add_edge(source.index(), node_index.index(), weight)?; } Ok(node_index.index()) diff --git a/src/tree.rs b/src/tree.rs index 11e2ba5b6..b9426346c 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -130,7 +130,7 @@ pub fn minimum_spanning_tree( .edges .iter() { - spanning_tree.add_edge(edge.0, edge.1, edge.2.clone_ref(py)); + spanning_tree.add_edge(edge.0, edge.1, edge.2.clone_ref(py))?; } Ok(spanning_tree) diff --git a/tests/rustworkx_tests/digraph/test_edges.py b/tests/rustworkx_tests/digraph/test_edges.py index 2d9f56ae5..ef0af66dd 100644 --- a/tests/rustworkx_tests/digraph/test_edges.py +++ b/tests/rustworkx_tests/digraph/test_edges.py @@ -963,6 +963,21 @@ def test_extend_from_weighted_edge_list(self): self.assertEqual(len(graph), 4) self.assertEqual(["a", "b", "c", "d", "e"], graph.edges()) + def test_add_edge_non_existent(self): + g = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + g.add_edge(2, 3, None) + + def test_add_edges_from_non_existent(self): + g = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + g.add_edges_from([(2, 3, 5)]) + + def test_add_edges_from_no_data_non_existent(self): + g = rustworkx.PyDiGraph() + with self.assertRaises(IndexError): + g.add_edges_from_no_data([(2, 3)]) + def test_reverse_graph(self): graph = rustworkx.PyDiGraph() graph.add_nodes_from([i for i in range(4)]) @@ -978,7 +993,7 @@ def test_reverse_graph(self): self.assertEqual([(1, 0), (2, 1), (2, 0), (3, 2), (3, 0)], graph.edge_list()) def 
test_reverse_large_graph(self): - LARGE_AMOUNT_OF_NODES = 10000000 + LARGE_AMOUNT_OF_NODES = 1000000 graph = rustworkx.PyDiGraph() graph.add_nodes_from(range(LARGE_AMOUNT_OF_NODES)) diff --git a/tests/rustworkx_tests/graph/test_edges.py b/tests/rustworkx_tests/graph/test_edges.py index 4981225a5..04f24af1a 100644 --- a/tests/rustworkx_tests/graph/test_edges.py +++ b/tests/rustworkx_tests/graph/test_edges.py @@ -817,3 +817,18 @@ def test_extend_from_weighted_edge_list(self): graph.extend_from_weighted_edge_list(edge_list) self.assertEqual(len(graph), 4) self.assertEqual(["a", "b", "c", "d", "e"], graph.edges()) + + def test_add_edge_non_existent(self): + g = rustworkx.PyGraph() + with self.assertRaises(IndexError): + g.add_edge(2, 3, None) + + def test_add_edges_from_non_existent(self): + g = rustworkx.PyGraph() + with self.assertRaises(IndexError): + g.add_edges_from([(2, 3, 5)]) + + def test_add_edges_from_no_data_non_existent(self): + g = rustworkx.PyGraph() + with self.assertRaises(IndexError): + g.add_edges_from_no_data([(2, 3)]) From c1806dd8a86e63274d9f6d7ed5d4361f96ecc13a Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Sun, 14 May 2023 18:17:51 -0400 Subject: [PATCH 10/37] Fix incorrect test case (#872) Closes #798 I generated the values the test case is checking against using NetworkX. So the implementation is correct, it was just the test case and the assert_almost_equal that were incorrect --- rustworkx-core/src/centrality.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index d454741a9..2866a6d0e 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -681,7 +681,7 @@ mod test_eigenvector_centrality { macro_rules! assert_almost_equal { ($x:expr, $y:expr, $d:expr) => { - if !($x - $y < $d || $y - $x < $d) { + if ($x - $y).abs() >= $d { panic!("{} != {} within delta of {}", $x, $y, $d); } }; @@ -753,8 +753,7 @@ mod test_eigenvector_centrality { let output: Result>> = eigenvector_centrality(&g, |_| Ok(2.), None, None); let result = output.unwrap().unwrap(); let expected_values: Vec = vec![ - 0.25368793, 0.19576478, 0.32817092, 0.40430835, 0.48199885, 0.15724483, 0.51346196, - 0.32475403, + 0.2140437, 0.2009269, 0.1036383, 0.0972886, 0.3113323, 0.4891686, 0.4420605, 0.6016448, ]; for i in 0..8 { assert_almost_equal!(expected_values[i], result[i], 1e-4); From ae95d1d412abba2e0cf5021fe486859b915f54f8 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Mon, 15 May 2023 11:26:27 -0400 Subject: [PATCH 11/37] Add PyDiGraph method to make edges symmetric (#814) * Add PyDiGraph method to make edges symmetric This commit adds a new method make_symmetric() to PyDiGraph which will modify the graph and add a reverse edge to each edge in the graph if it is not already present. 
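As a usage sketch (drawn from the new tests further below), making a directed path graph symmetric adds the missing reverse edges in place::

    import rustworkx as rx

    digraph = rx.generators.directed_path_graph(4, bidirectional=False)
    digraph.make_symmetric()
    # every original edge now has a reverse counterpart
    assert set(digraph.edge_list()) == {(0, 1), (1, 2), (2, 3), (1, 0), (2, 1), (3, 2)}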
* Simplyify logic * Add initial tests * Update docstring wording Co-authored-by: John Lapeyre * Add release notes * Expand testing * Fix tests and cycle checking * Remove stray debug prints * Update src/digraph.rs Co-authored-by: John Lapeyre * Fix lint --------- Co-authored-by: John Lapeyre Co-authored-by: Edwin Navarro --- ...graph-make-symmetric-60d0287a7f7eec04.yaml | 16 ++++ src/digraph.rs | 37 +++++++++ .../rustworkx_tests/digraph/test_symmetric.py | 83 +++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 releasenotes/notes/add-digraph-make-symmetric-60d0287a7f7eec04.yaml diff --git a/releasenotes/notes/add-digraph-make-symmetric-60d0287a7f7eec04.yaml b/releasenotes/notes/add-digraph-make-symmetric-60d0287a7f7eec04.yaml new file mode 100644 index 000000000..ce24c0f66 --- /dev/null +++ b/releasenotes/notes/add-digraph-make-symmetric-60d0287a7f7eec04.yaml @@ -0,0 +1,16 @@ +--- +features: + - | + Added a new method, :meth:`~.PyDiGraph.make_symmetric`, to the + :class:`~.PyDiGraph` class. This method is used to make all the edges + in the graph symmetric (there is a reverse edge in the graph for each edge). + For example: + + .. jupyter-execute:: + + import rustworkx as rx + from rustworkx.visualization import graphviz_draw + + graph = rx.generators.directed_path_graph(5, bidirectional=False) + graph.make_symmetric() + graphviz_draw(graph) diff --git a/src/digraph.rs b/src/digraph.rs index e0a830225..a0d63b09c 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2712,6 +2712,43 @@ impl PyDiGraph { edges.is_empty() } + /// Make edges in graph symmetric + /// + /// This function iterates over all the edges in the graph, adding for each + /// edge the reversed edge, unless one is already present. Note the edge insertion + /// is not fixed and the edge indices are not guaranteed to be consistent + /// between executions of this method on identical graphs. + /// + /// :param callable edge_payload: This optional argument takes in a callable which will + /// be passed a single positional argument the data payload for an edge that will + /// have a reverse copied in the graph. The returned value from this callable will + /// be used as the data payload for the new edge created. If this is not specified + /// then by default the data payload will be copied when the reverse edge is added. + /// If there are parallel edges, then one of the edges (typically the one with the lower + /// index, but this is not a guarantee) will be copied. 
+ pub fn make_symmetric( + &mut self, + py: Python, + edge_payload_fn: Option, + ) -> PyResult<()> { + let edges: HashMap<[NodeIndex; 2], EdgeIndex> = self + .graph + .edge_references() + .map(|edge| ([edge.source(), edge.target()], edge.id())) + .collect(); + for ([edge_source, edge_target], edge_index) in edges.iter() { + if !edges.contains_key(&[*edge_target, *edge_source]) { + let forward_weight = self.graph.edge_weight(*edge_index).unwrap(); + let weight: PyObject = match edge_payload_fn.as_ref() { + Some(callback) => callback.call1(py, (forward_weight,))?, + None => forward_weight.clone_ref(py), + }; + self._add_edge(*edge_target, *edge_source, weight)?; + } + } + Ok(()) + } + /// Generate a new PyGraph object from this graph /// /// This will create a new :class:`~rustworkx.PyGraph` object from this diff --git a/tests/rustworkx_tests/digraph/test_symmetric.py b/tests/rustworkx_tests/digraph/test_symmetric.py index 67f336d9f..3c29f0d55 100644 --- a/tests/rustworkx_tests/digraph/test_symmetric.py +++ b/tests/rustworkx_tests/digraph/test_symmetric.py @@ -15,6 +15,10 @@ import rustworkx +def default_weight_function(edge): + return "Reversi" + + class TestSymmetric(unittest.TestCase): def test_single_neighbor(self): digraph = rustworkx.PyDiGraph() @@ -37,3 +41,82 @@ def test_bidirectional_ring(self): ] digraph.extend_from_edge_list(edge_list) self.assertTrue(digraph.is_symmetric()) + + def test_empty_graph_make_symmetric(self): + digraph = rustworkx.PyDiGraph() + digraph.make_symmetric() + self.assertEqual(0, digraph.num_edges()) + self.assertEqual(0, digraph.num_nodes()) + + def test_path_graph_make_symmetric(self): + digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False) + digraph.make_symmetric() + expected_edge_list = { + (0, 1), + (1, 2), + (2, 3), + (1, 0), + (2, 1), + (3, 2), + } + self.assertEqual(set(digraph.edge_list()), expected_edge_list) + + def test_path_graph_make_symmetric_existing_reverse_edges(self): + digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False) + digraph.add_edge(3, 2, None) + digraph.add_edge(1, 0, None) + digraph.make_symmetric() + expected_edge_list = { + (0, 1), + (1, 2), + (2, 3), + (3, 2), + (1, 0), + (2, 1), + } + self.assertEqual(set(digraph.edge_list()), expected_edge_list) + + def test_empty_graph_make_symmetric_with_function_arg(self): + digraph = rustworkx.PyDiGraph() + digraph.make_symmetric(default_weight_function) + self.assertEqual(0, digraph.num_edges()) + self.assertEqual(0, digraph.num_nodes()) + + def test_path_graph_make_symmetric_with_function_arg(self): + digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False) + digraph.make_symmetric(default_weight_function) + expected_edge_list = { + (0, 1, None), + (1, 2, None), + (2, 3, None), + (1, 0, "Reversi"), + (2, 1, "Reversi"), + (3, 2, "Reversi"), + } + result = set(digraph.weighted_edge_list()) + self.assertEqual(result, expected_edge_list) + + def test_path_graph_make_symmetric_existing_reverse_edges_function_arg(self): + digraph = rustworkx.generators.directed_path_graph(4, bidirectional=False) + digraph.add_edge(3, 2, None) + digraph.add_edge(1, 0, None) + digraph.make_symmetric(default_weight_function) + expected_edge_list = { + (0, 1, None), + (1, 2, None), + (2, 3, None), + (3, 2, None), + (1, 0, None), + (2, 1, "Reversi"), + } + self.assertEqual(set(digraph.weighted_edge_list()), expected_edge_list) + + def test_path_graph_make_symmetric_function_arg_raises(self): + digraph = rustworkx.generators.directed_path_graph(4) + + 
def weight_function(edge): + if edge is None: + raise TypeError("I'm expected") + + with self.assertRaises(TypeError): + digraph.make_symmetric(weight_function) From 1c6c7963d5cea3f988e308b841f88da2d888535d Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 16 May 2023 20:16:32 -0400 Subject: [PATCH 12/37] Improve documentation for graph_greedy_color (#857) * Improve documentation for graph_greedy_color This commit improves the documentation for the graph_greedy_color function. Previously, the details on the function and the algorithm it implemented where a bit sparse. This commit expands it by explaining the source for the algorithm, making it clear it's not always going to return an optimal solution, and also adding an example. * Update src/coloring.rs * Change to use note directive * Update src/coloring.rs Co-authored-by: Julien Gacon --------- Co-authored-by: Julien Gacon --- src/coloring.rs | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/coloring.rs b/src/coloring.rs index 329eb8385..b432ee1d3 100644 --- a/src/coloring.rs +++ b/src/coloring.rs @@ -26,13 +26,38 @@ use petgraph::visit::NodeCount; use rayon::prelude::*; -/// Color a PyGraph using a largest_first strategy greedy graph coloring. +/// Color a :class:`~.PyGraph` object using a greedy graph coloring algorithm. +/// +/// This function uses a `largest-first` strategy as described in [1]_ and colors +/// the nodes with higher degree first. +/// +/// .. note:: +/// +/// The coloring problem is NP-hard and this is a heuristic algorithm which +/// may not return an optimal solution. /// /// :param PyGraph: The input PyGraph object to color /// /// :returns: A dictionary where keys are node indices and the value is /// the color /// :rtype: dict +/// +/// .. jupyter-execute:: +/// +/// import rustworkx as rx +/// from rustworkx.visualization import mpl_draw +/// +/// graph = rx.generators.generalized_petersen_graph(5, 2) +/// coloring = rx.graph_greedy_color(graph) +/// colors = [coloring[node] for node in graph.node_indices()] +/// +/// # Draw colored graph +/// layout = rx.shell_layout(graph, nlist=[[0, 1, 2, 3, 4],[6, 7, 8, 9, 5]]) +/// mpl_draw(graph, node_color=colors, pos=layout) +/// +/// +/// .. [1] Adrian Kosowski, and Krzysztof Manuszewski, Classical Coloring of Graphs, +/// Graph Colorings, 2-19, 2004. ISBN 0-8218-3458-4. #[pyfunction] #[pyo3(text_signature = "(graph, /)")] pub fn graph_greedy_color(py: Python, graph: &graph::PyGraph) -> PyResult { From 57f873f833530ee999d87dfbf285813d20d13928 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Tue, 16 May 2023 23:53:07 -0400 Subject: [PATCH 13/37] Switch to new documentation links to avoid a redirect (#877) * Switch old documentation links to new one (saves a redirect) * More links --- README.md | 2 +- docs/source/conf.py | 4 ++-- rustworkx-core/src/lib.rs | 2 +- setup.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 5511a0fef..899fedc76 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ [![Zenodo](https://img.shields.io/badge/Zenodo-10.5281%2Fzenodo.5879859-blue)](https://doi.org/10.5281/zenodo.5879859) - You can see the full rendered docs at: - + |:warning:| The retworkx project has been renamed to **rustworkx**. 
The use of the retworkx package will still work for the time being but starting in the 1.0.0 diff --git a/docs/source/conf.py b/docs/source/conf.py index 61e13c1b8..4e8977613 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -126,8 +126,8 @@ redirects[f"stubs/{source_str}"] = f"../apiref/{source_str}" if os.getenv("RETWORKX_LEGACY_DOCS", None) is not None: - redirects["*"] = "https://qiskit.org/documentation/rustworkx/$source.html" - html_baseurl = "https://qiskit.org/documentation/rustworkx/" + redirects["*"] = "https://qiskit.org/ecosystem/rustworkx/$source.html" + html_baseurl = "https://qiskit.org/ecosystem/rustworkx/" # Version extensions diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index ab54ad5dc..d1c7e72ef 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -59,7 +59,7 @@ //! The release notes for rustworkx-core are included as part of the rustworkx //! documentation which is hosted at: //! -//! +//! use std::convert::Infallible; diff --git a/setup.py b/setup.py index e0c73be2e..02d3ff6a5 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ def readme(): project_urls={ "Bug Tracker": "https://github.com/Qiskit/rustworkx/issues", "Source Code": "https://github.com/Qiskit/rustworkx", - "Documentation": "https://qiskit.org/documentation/rustworkx", + "Documentation": "https://qiskit.org/ecosystem/rustworkx/", }, rust_extensions=RUST_EXTENSIONS, include_package_data=True, From 3767c3727b9b8390dbd6f05a369343e359b0cd61 Mon Sep 17 00:00:00 2001 From: Edwin Navarro Date: Wed, 17 May 2023 10:54:34 -0700 Subject: [PATCH 14/37] Add random generators to rustworkx-core (#818) * Initial gnp random * gnp and gnm graphs * First tests and connect to python * Finish gnp and gnm * Finish testing and reno * Cleanup * Fix random tests and add to release note * Fix token_swapper test * Fix node payload and other minor stuff * Lint * Minor format * Test fixes --------- Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> --- ...dd-random-generators-9f99e57b5e4188f2.yaml | 18 + rustworkx-core/src/generators/mod.rs | 4 + rustworkx-core/src/generators/random_graph.rs | 577 ++++++++++++++++++ src/random_graph.rs | 309 +++------- .../retworkx_backwards_compat/test_random.py | 14 +- tests/rustworkx_tests/test_random.py | 30 +- 6 files changed, 725 insertions(+), 227 deletions(-) create mode 100644 releasenotes/notes/add-random-generators-9f99e57b5e4188f2.yaml create mode 100644 rustworkx-core/src/generators/random_graph.rs diff --git a/releasenotes/notes/add-random-generators-9f99e57b5e4188f2.yaml b/releasenotes/notes/add-random-generators-9f99e57b5e4188f2.yaml new file mode 100644 index 000000000..7e9d5cba2 --- /dev/null +++ b/releasenotes/notes/add-random-generators-9f99e57b5e4188f2.yaml @@ -0,0 +1,18 @@ +--- +features: + - | + Three new random graph generators, ``gnp_random_graph``, ``gnm_random_graph`` + and ``random_geometric_graph``, have been added to the ``rustworkx-core`` + crate in the ``generators`` module. The ``gnp_random_graph`` takes as input + the number of nodes and a probability for adding edges. The ``gnm_random_graph`` + takes as input the number of nodes and the number of edges. The + ``random_geometric_graph`` creates a random graph within an n-dimensional + cube. +upgrade: + - | + Passing a negative value to the ``probability`` argument to the + :func:`~rustworkx.directed_gnp_random_graph` or the + :func:`~rustworkx.undirected_gnp_random_graph` function will now cause + an ``OverflowError`` to be raised.
Previously, a ``ValueError`` would be + raised in this situation. This was changed to be consistent with other similar + error conditions in other functions in the library. diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs index 40237685d..1b1b82f80 100644 --- a/rustworkx-core/src/generators/mod.rs +++ b/rustworkx-core/src/generators/mod.rs @@ -24,6 +24,7 @@ mod hexagonal_lattice_graph; mod lollipop_graph; mod path_graph; mod petersen_graph; +mod random_graph; mod star_graph; mod utils; @@ -55,4 +56,7 @@ pub use hexagonal_lattice_graph::hexagonal_lattice_graph; pub use lollipop_graph::lollipop_graph; pub use path_graph::path_graph; pub use petersen_graph::petersen_graph; +pub use random_graph::gnm_random_graph; +pub use random_graph::gnp_random_graph; +pub use random_graph::random_geometric_graph; pub use star_graph::star_graph; diff --git a/rustworkx-core/src/generators/random_graph.rs b/rustworkx-core/src/generators/random_graph.rs new file mode 100644 index 000000000..235adbb71 --- /dev/null +++ b/rustworkx-core/src/generators/random_graph.rs @@ -0,0 +1,577 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +#![allow(clippy::float_cmp)] + +use std::hash::Hash; + +use petgraph::data::{Build, Create}; +use petgraph::visit::{Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, NodeIndexable}; + +use rand::distributions::{Distribution, Uniform}; +use rand::prelude::*; +use rand_pcg::Pcg64; + +use super::InvalidInputError; + +/// Generate a Gnp random graph, also known as an +/// Erdős-Rényi graph or a binomial graph. +/// +/// For number of nodes `n` and probability `p`, the Gnp +/// graph algorithm creates `n` nodes, and for all the `n * (n - 1)` possible edges, +/// each edge is created independently with probability `p`. +/// In general, for any probability `p`, the expected number of edges returned +/// is `m = p * n * (n - 1)`. If `p = 0` or `p = 1`, the returned +/// graph is not random and will always be an empty or a complete graph respectively. +/// An empty graph has zero edges and a complete directed graph has `n (n - 1)` edges. +/// The run time is `O(n + m)` where `m` is the expected number of edges mentioned above. +/// When `p = 0`, run time always reduces to `O(n)`, as the lower bound. +/// When `p = 1`, run time always goes to `O(n + n * (n - 1))`, as the upper bound. +/// +/// For `0 < p < 1`, the algorithm is based on the implementation of the networkx function +/// ``fast_gnp_random_graph``, +/// +/// +/// Vladimir Batagelj and Ulrik Brandes, +/// "Efficient generation of large random networks", +/// Phys. Rev. E, 71, 036113, 2005. +/// +/// Arguments: +/// +/// * `num_nodes` - The number of nodes for creating the random graph. +/// * `probability` - The probability of creating an edge between two nodes as a float. +/// * `seed` - An optional seed to use for the random number generator. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. 
+/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::gnp_random_graph; +/// +/// let g: petgraph::graph::DiGraph<(), ()> = gnp_random_graph( +/// 20, +/// 1.0, +/// None, +/// || {()}, +/// || {()}, +/// ).unwrap(); +/// assert_eq!(g.node_count(), 20); +/// assert_eq!(g.edge_count(), 20 * (20 - 1)); +/// ``` +pub fn gnp_random_graph( + num_nodes: usize, + probability: f64, + seed: Option, + mut default_node_weight: F, + mut default_edge_weight: H, +) -> Result +where + G: Build + Create + Data + NodeIndexable + GraphProp, + F: FnMut() -> T, + H: FnMut() -> M, + G::NodeId: Eq + Hash, +{ + if num_nodes == 0 { + return Err(InvalidInputError {}); + } + let mut rng: Pcg64 = match seed { + Some(seed) => Pcg64::seed_from_u64(seed), + None => Pcg64::from_entropy(), + }; + let mut graph = G::with_capacity(num_nodes, num_nodes); + let directed = graph.is_directed(); + + for _ in 0..num_nodes { + graph.add_node(default_node_weight()); + } + if !(0.0..=1.0).contains(&probability) { + return Err(InvalidInputError {}); + } + if probability > 0.0 { + if (probability - 1.0).abs() < std::f64::EPSILON { + for u in 0..num_nodes { + let start_node = if directed { 0 } else { u + 1 }; + for v in start_node..num_nodes { + if !directed || u != v { + // exclude self-loops + let u_index = graph.from_index(u); + let v_index = graph.from_index(v); + graph.add_edge(u_index, v_index, default_edge_weight()); + } + } + } + } else { + let mut v: isize = if directed { 0 } else { 1 }; + let mut w: isize = -1; + let num_nodes: isize = num_nodes as isize; + let lp: f64 = (1.0 - probability).ln(); + + let between = Uniform::new(0.0, 1.0); + while v < num_nodes { + let random: f64 = between.sample(&mut rng); + let lr: f64 = (1.0 - random).ln(); + let ratio: isize = (lr / lp) as isize; + w = w + 1 + ratio; + + if directed { + // avoid self loops + if v == w { + w += 1; + } + } + while v < num_nodes && ((directed && num_nodes <= w) || (!directed && v <= w)) { + w -= v; + v += 1; + // avoid self loops + if directed && v == w { + w -= v; + v += 1; + } + } + if v < num_nodes { + let v_index = graph.from_index(v as usize); + let w_index = graph.from_index(w as usize); + graph.add_edge(v_index, w_index, default_edge_weight()); + } + } + } + } + Ok(graph) +} + +// /// Return a `G_{nm}` directed graph, also known as an +// /// Erdős-Rényi graph. +// /// +// /// Generates a random directed graph out of all the possible graphs with `n` nodes and +// /// `m` edges. The generated graph will not be a multigraph and will not have self loops. +// /// +// /// For `n` nodes, the maximum edges that can be returned is `n (n - 1)`. +// /// Passing `m` higher than that will still return the maximum number of edges. +// /// If `m = 0`, the returned graph will always be empty (no edges). +// /// When a seed is provided, the results are reproducible. Passing a seed when `m = 0` +// /// or `m >= n (n - 1)` has no effect, as the result will always be an empty or a complete graph respectively. +// /// +// /// This algorithm has a time complexity of `O(n + m)` + +/// Generate a Gnm random graph, also known as an +/// Erdős-Rényi graph. +/// +/// Generates a random directed graph out of all the possible graphs with `n` nodes and +/// `m` edges. The generated graph will not be a multigraph and will not have self loops. 
+/// +/// For `n` nodes, the maximum edges that can be returned is `n * (n - 1)`. +/// Passing `m` higher than that will still return the maximum number of edges. +/// If `m = 0`, the returned graph will always be empty (no edges). +/// When a seed is provided, the results are reproducible. Passing a seed when `m = 0` +/// or `m >= n * (n - 1)` has no effect, as the result will always be an empty or a +/// complete graph respectively. +/// +/// This algorithm has a time complexity of `O(n + m)` +/// +/// Arguments: +/// +/// * `num_nodes` - The number of nodes to create in the graph. +/// * `num_edges` - The number of edges to create in the graph. +/// * `seed` - An optional seed to use for the random number generator. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::gnm_random_graph; +/// +/// let g: petgraph::graph::DiGraph<(), ()> = gnm_random_graph( +/// 20, +/// 12, +/// None, +/// || {()}, +/// || {()}, +/// ).unwrap(); +/// assert_eq!(g.node_count(), 20); +/// assert_eq!(g.edge_count(), 12); +/// ``` +pub fn gnm_random_graph( + num_nodes: usize, + num_edges: usize, + seed: Option, + mut default_node_weight: F, + mut default_edge_weight: H, +) -> Result +where + G: GraphProp + Build + Create + Data + NodeIndexable, + F: FnMut() -> T, + H: FnMut() -> M, + for<'b> &'b G: GraphBase + IntoEdgeReferences, + G::NodeId: Eq + Hash, +{ + if num_nodes == 0 { + return Err(InvalidInputError {}); + } + + fn find_edge(graph: &G, source: usize, target: usize) -> bool + where + G: GraphBase + NodeIndexable, + for<'b> &'b G: GraphBase + IntoEdgeReferences, + { + let mut found = false; + for edge in graph.edge_references() { + if graph.to_index(edge.source()) == source && graph.to_index(edge.target()) == target { + found = true; + break; + } + } + found + } + + let mut rng: Pcg64 = match seed { + Some(seed) => Pcg64::seed_from_u64(seed), + None => Pcg64::from_entropy(), + }; + let mut graph = G::with_capacity(num_nodes, num_edges); + let directed = graph.is_directed(); + + for _ in 0..num_nodes { + graph.add_node(default_node_weight()); + } + // if number of edges to be created is >= max, + // avoid randomly missed trials and directly add edges between every node + let div_by = if directed { 1 } else { 2 }; + if num_edges >= num_nodes * (num_nodes - 1) / div_by { + for u in 0..num_nodes { + let start_node = if directed { 0 } else { u + 1 }; + for v in start_node..num_nodes { + // avoid self-loops + if !directed || u != v { + let u_index = graph.from_index(u); + let v_index = graph.from_index(v); + graph.add_edge(u_index, v_index, default_edge_weight()); + } + } + } + } else { + let mut created_edges: usize = 0; + let between = Uniform::new(0, num_nodes); + while created_edges < num_edges { + let u = between.sample(&mut rng); + let v = between.sample(&mut rng); + let u_index = graph.from_index(u); + let v_index = graph.from_index(v); + // avoid self-loops and multi-graphs + if u != v && !find_edge(&graph, u, v) { + graph.add_edge(u_index, v_index, default_edge_weight()); + created_edges += 1; + } + } + } + Ok(graph) +} + +#[inline] +fn pnorm(x: f64, p: f64) -> f64 { + if p == 1.0 || p == std::f64::INFINITY { + x.abs() + } else if p == 2.0 { + x * x + } else { + x.abs().powf(p) + } +} + +fn distance(x: &[f64], y: &[f64], p: f64) -> 
f64 { + let it = x.iter().zip(y.iter()).map(|(xi, yi)| pnorm(xi - yi, p)); + + if p == std::f64::INFINITY { + it.fold(-1.0, |max, x| if x > max { x } else { max }) + } else { + it.sum() + } +} + +/// Generate a random geometric graph in the unit cube of dimensions `dim`. +/// +/// The random geometric graph model places `num_nodes` nodes uniformly at +/// random in the unit cube. Two nodes are joined by an edge if the +/// distance between the nodes is at most `radius`. +/// +/// Each node has a node attribute ``'pos'`` that stores the +/// position of that node in Euclidean space as provided by the +/// ``pos`` keyword argument or, if ``pos`` was not provided, as +/// generated by this function. +/// +/// Arguments +/// +/// * `num_nodes` - The number of nodes to create in the graph. +/// * `radius` - Distance threshold value. +/// * `dim` - Dimension of node positions. Default: 2 +/// * `pos` - Optional list with node positions as values. +/// * `p` - Which Minkowski distance metric to use. `p` has to meet the condition +/// ``1 <= p <= infinity``. +/// If this argument is not specified, the L2 metric +/// (the Euclidean distance metric), `p = 2` is used. +/// * `seed` - An optional seed to use for the random number generator. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::random_geometric_graph; +/// +/// let g: petgraph::graph::UnGraph, ()> = random_geometric_graph( +/// 10, +/// 1.42, +/// 2, +/// None, +/// 2.0, +/// None, +/// || {()}, +/// ).unwrap(); +/// assert_eq!(g.node_count(), 10); +/// assert_eq!(g.edge_count(), 45); +/// ``` +pub fn random_geometric_graph( + num_nodes: usize, + radius: f64, + dim: usize, + pos: Option>>, + p: f64, + seed: Option, + mut default_edge_weight: H, +) -> Result +where + G: GraphBase + Build + Create + Data, EdgeWeight = M> + NodeIndexable, + H: FnMut() -> M, + for<'b> &'b G: GraphBase + IntoEdgeReferences, + G::NodeId: Eq + Hash, +{ + if num_nodes == 0 { + return Err(InvalidInputError {}); + } + let mut rng: Pcg64 = match seed { + Some(seed) => Pcg64::seed_from_u64(seed), + None => Pcg64::from_entropy(), + }; + let mut graph = G::with_capacity(num_nodes, num_nodes); + + let radius_p = pnorm(radius, p); + let dist = Uniform::new(0.0, 1.0); + let pos = pos.unwrap_or_else(|| { + (0..num_nodes) + .map(|_| (0..dim).map(|_| dist.sample(&mut rng)).collect()) + .collect() + }); + if num_nodes != pos.len() { + return Err(InvalidInputError {}); + } + for pval in pos.iter() { + graph.add_node(pval.clone()); + } + for u in 0..(num_nodes - 1) { + for v in (u + 1)..num_nodes { + if distance(&pos[u], &pos[v], p) < radius_p { + graph.add_edge( + graph.from_index(u), + graph.from_index(v), + default_edge_weight(), + ); + } + } + } + Ok(graph) +} + +#[cfg(test)] +mod tests { + use crate::generators::InvalidInputError; + use crate::generators::{gnm_random_graph, gnp_random_graph, random_geometric_graph}; + use crate::petgraph; + + // Test gnp_random_graph + + #[test] + fn test_gnp_random_graph_directed() { + let g: petgraph::graph::DiGraph<(), ()> = + gnp_random_graph(20, 0.5, Some(10), || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 104); + } + + #[test] + fn test_gnp_random_graph_directed_empty() { + let g: petgraph::graph::DiGraph<(), ()> = + gnp_random_graph(20, 0.0, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + 
assert_eq!(g.edge_count(), 0); + } + + #[test] + fn test_gnp_random_graph_directed_complete() { + let g: petgraph::graph::DiGraph<(), ()> = + gnp_random_graph(20, 1.0, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 20 * (20 - 1)); + } + + #[test] + fn test_gnp_random_graph_undirected() { + let g: petgraph::graph::UnGraph<(), ()> = + gnp_random_graph(20, 0.5, Some(10), || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 105); + } + + #[test] + fn test_gnp_random_graph_undirected_empty() { + let g: petgraph::graph::UnGraph<(), ()> = + gnp_random_graph(20, 0.0, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 0); + } + + #[test] + fn test_gnp_random_graph_undirected_complete() { + let g: petgraph::graph::UnGraph<(), ()> = + gnp_random_graph(20, 1.0, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 20 * (20 - 1) / 2); + } + + #[test] + fn test_gnp_random_graph_error() { + match gnp_random_graph::, (), _, _, ()>( + 0, + 3.0, + None, + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + + // Test gnm_random_graph + + #[test] + fn test_gnm_random_graph_directed() { + let g: petgraph::graph::DiGraph<(), ()> = + gnm_random_graph(20, 100, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 100); + } + + #[test] + fn test_gnm_random_graph_directed_empty() { + let g: petgraph::graph::DiGraph<(), ()> = + gnm_random_graph(20, 0, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 0); + } + + #[test] + fn test_gnm_random_graph_directed_complete() { + let g: petgraph::graph::DiGraph<(), ()> = + gnm_random_graph(20, 20 * (20 - 1), None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 20 * (20 - 1)); + } + + #[test] + fn test_gnm_random_graph_directed_max_edges() { + let n = 20; + let max_m = n * (n - 1); + let g: petgraph::graph::DiGraph<(), ()> = + gnm_random_graph(n, max_m, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), n); + assert_eq!(g.edge_count(), max_m); + // passing the max edges for the passed number of nodes + let g: petgraph::graph::DiGraph<(), ()> = + gnm_random_graph(n, max_m + 1, None, || (), || ()).unwrap(); + assert_eq!(g.node_count(), n); + assert_eq!(g.edge_count(), max_m); + // passing a seed when passing max edges has no effect + let g: petgraph::graph::DiGraph<(), ()> = + gnm_random_graph(n, max_m, Some(55), || (), || ()).unwrap(); + assert_eq!(g.node_count(), n); + assert_eq!(g.edge_count(), max_m); + } + + #[test] + fn test_gnm_random_graph_error() { + match gnm_random_graph::, (), _, _, ()>( + 0, + 0, + None, + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + + // Test random_geometric_graph + + #[test] + fn test_random_geometric_empty() { + let g: petgraph::graph::UnGraph, ()> = + random_geometric_graph(20, 0.0, 2, None, 2.0, None, || ()).unwrap(); + assert_eq!(g.node_count(), 20); + assert_eq!(g.edge_count(), 0); + } + + #[test] + fn test_random_geometric_complete() { + let g: petgraph::graph::UnGraph, ()> = + random_geometric_graph(10, 1.42, 2, None, 2.0, None, || ()).unwrap(); + assert_eq!(g.node_count(), 10); + assert_eq!(g.edge_count(), 45); + } + + #[test] + fn test_random_geometric_bad_num_nodes() { + match random_geometric_graph::, ()>, 
_, ()>( + 0, + 1.0, + 2, + None, + 2.0, + None, + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } + + #[test] + fn test_random_geometric_bad_pos() { + match random_geometric_graph::, ()>, _, ()>( + 3, + 0.15, + 3, + Some(vec![vec![0.5, 0.5]]), + 2.0, + None, + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } +} diff --git a/src/random_graph.rs b/src/random_graph.rs index 408544064..c9e830686 100644 --- a/src/random_graph.rs +++ b/src/random_graph.rs @@ -27,6 +27,8 @@ use rand::distributions::{Distribution, Uniform}; use rand::prelude::*; use rand_pcg::Pcg64; +use rustworkx_core::generators as core_generators; + /// Return a :math:`G_{np}` directed random graph, also known as an /// Erdős-Rényi graph or a binomial graph. /// @@ -57,83 +59,42 @@ use rand_pcg::Pcg64; /// Phys. Rev. E, 71, 036113, 2005. /// .. [2] https://github.com/networkx/networkx/blob/networkx-2.4/networkx/generators/random_graphs.py#L49-L120 #[pyfunction] -#[pyo3(text_signature = "(num_nodes, probability, seed=None, /)")] +#[pyo3(text_signature = "(num_nodes, probability, /, seed=None)")] pub fn directed_gnp_random_graph( py: Python, - num_nodes: isize, + num_nodes: usize, probability: f64, seed: Option, ) -> PyResult { - if num_nodes <= 0 { - return Err(PyValueError::new_err("num_nodes must be > 0")); - } - let mut rng: Pcg64 = match seed { - Some(seed) => Pcg64::seed_from_u64(seed), - None => Pcg64::from_entropy(), - }; - let mut inner_graph = StablePyGraph::::new(); - for x in 0..num_nodes { - inner_graph.add_node(x.to_object(py)); - } - if !(0.0..=1.0).contains(&probability) { - return Err(PyValueError::new_err( - "Probability out of range, must be 0 <= p <= 1", - )); - } - if probability > 0.0 { - if (probability - 1.0).abs() < std::f64::EPSILON { - for u in 0..num_nodes { - for v in 0..num_nodes { - if u != v { - // exclude self-loops - let u_index = NodeIndex::new(u as usize); - let v_index = NodeIndex::new(v as usize); - inner_graph.add_edge(u_index, v_index, py.None()); - } - } - } - } else { - let mut v: isize = 0; - let mut w: isize = -1; - let lp: f64 = (1.0 - probability).ln(); - - let between = Uniform::new(0.0, 1.0); - while v < num_nodes { - let random: f64 = between.sample(&mut rng); - let lr: f64 = (1.0 - random).ln(); - let ratio: isize = (lr / lp) as isize; - w = w + 1 + ratio; - // avoid self loops - if v == w { - w += 1; - } - while v < num_nodes && num_nodes <= w { - w -= v; - v += 1; - // avoid self loops - if v == w { - w -= v; - v += 1; - } - } - if v < num_nodes { - let v_index = NodeIndex::new(v as usize); - let w_index = NodeIndex::new(w as usize); - inner_graph.add_edge(v_index, w_index, py.None()); - } - } + let default_fn = || py.None(); + let mut graph: StablePyGraph = match core_generators::gnp_random_graph( + num_nodes, + probability, + seed, + default_fn, + default_fn, + ) { + Ok(graph) => graph, + Err(_) => { + return Err(PyValueError::new_err( + "num_nodes or probability invalid input", + )) } + }; + // Core function does not put index into node payload, so for backwards compat + // in the python interface, we do it here. 
+ let nodes: Vec = graph.node_indices().collect(); + for node in nodes.iter() { + graph[*node] = node.index().to_object(py); } - - let graph = digraph::PyDiGraph { - graph: inner_graph, - cycle_state: algo::DfsSpace::default(), - check_cycle: false, + Ok(digraph::PyDiGraph { + graph, node_removed: false, - multigraph: true, + check_cycle: false, + cycle_state: algo::DfsSpace::default(), + multigraph: false, attrs: py.None(), - }; - Ok(graph) + }) } /// Return a :math:`G_{np}` random undirected graph, also known as an @@ -166,69 +127,40 @@ pub fn directed_gnp_random_graph( /// Phys. Rev. E, 71, 036113, 2005. /// .. [2] https://github.com/networkx/networkx/blob/networkx-2.4/networkx/generators/random_graphs.py#L49-L120 #[pyfunction] -#[pyo3(text_signature = "(num_nodes, probability, seed=None, /)")] +#[pyo3(text_signature = "(num_nodes, probability, /, seed=None)")] pub fn undirected_gnp_random_graph( py: Python, - num_nodes: isize, + num_nodes: usize, probability: f64, seed: Option, ) -> PyResult { - if num_nodes <= 0 { - return Err(PyValueError::new_err("num_nodes must be > 0")); - } - let mut rng: Pcg64 = match seed { - Some(seed) => Pcg64::seed_from_u64(seed), - None => Pcg64::from_entropy(), - }; - let mut inner_graph = StablePyGraph::::default(); - for x in 0..num_nodes { - inner_graph.add_node(x.to_object(py)); - } - if !(0.0..=1.0).contains(&probability) { - return Err(PyValueError::new_err( - "Probability out of range, must be 0 <= p <= 1", - )); - } - if probability > 0.0 { - if (probability - 1.0).abs() < std::f64::EPSILON { - for u in 0..num_nodes { - for v in u + 1..num_nodes { - let u_index = NodeIndex::new(u as usize); - let v_index = NodeIndex::new(v as usize); - inner_graph.add_edge(u_index, v_index, py.None()); - } - } - } else { - let mut v: isize = 1; - let mut w: isize = -1; - let lp: f64 = (1.0 - probability).ln(); - - let between = Uniform::new(0.0, 1.0); - while v < num_nodes { - let random: f64 = between.sample(&mut rng); - let lr = (1.0 - random).ln(); - let ratio: isize = (lr / lp) as isize; - w = w + 1 + ratio; - while w >= v && v < num_nodes { - w -= v; - v += 1; - } - if v < num_nodes { - let v_index = NodeIndex::new(v as usize); - let w_index = NodeIndex::new(w as usize); - inner_graph.add_edge(v_index, w_index, py.None()); - } - } + let default_fn = || py.None(); + let mut graph: StablePyGraph = match core_generators::gnp_random_graph( + num_nodes, + probability, + seed, + default_fn, + default_fn, + ) { + Ok(graph) => graph, + Err(_) => { + return Err(PyValueError::new_err( + "num_nodes or probability invalid input", + )) } + }; + // Core function does not put index into node payload, so for backwards compat + // in the python interface, we do it here. 
+ let nodes: Vec = graph.node_indices().collect(); + for node in nodes.iter() { + graph[*node] = node.index().to_object(py); } - - let graph = graph::PyGraph { - graph: inner_graph, + Ok(graph::PyGraph { + graph, node_removed: false, multigraph: true, attrs: py.None(), - }; - Ok(graph) + }) } /// Return a :math:`G_{nm}` directed graph, also known as an @@ -256,61 +188,35 @@ pub fn undirected_gnp_random_graph( #[pyo3(text_signature = "(num_nodes, num_edges, /, seed=None)")] pub fn directed_gnm_random_graph( py: Python, - num_nodes: isize, - num_edges: isize, + num_nodes: usize, + num_edges: usize, seed: Option, ) -> PyResult { - if num_nodes <= 0 { - return Err(PyValueError::new_err("num_nodes must be > 0")); - } - if num_edges < 0 { - return Err(PyValueError::new_err("num_edges must be >= 0")); - } - let mut rng: Pcg64 = match seed { - Some(seed) => Pcg64::seed_from_u64(seed), - None => Pcg64::from_entropy(), - }; - let mut inner_graph = StablePyGraph::::new(); - for x in 0..num_nodes { - inner_graph.add_node(x.to_object(py)); - } - // if number of edges to be created is >= max, - // avoid randomly missed trials and directly add edges between every node - if num_edges >= num_nodes * (num_nodes - 1) { - for u in 0..num_nodes { - for v in 0..num_nodes { - // avoid self-loops - if u != v { - let u_index = NodeIndex::new(u as usize); - let v_index = NodeIndex::new(v as usize); - inner_graph.add_edge(u_index, v_index, py.None()); - } + let default_fn = || py.None(); + let mut graph: StablePyGraph = + match core_generators::gnm_random_graph(num_nodes, num_edges, seed, default_fn, default_fn) + { + Ok(graph) => graph, + Err(_) => { + return Err(PyValueError::new_err( + "num_nodes or num_edges invalid input", + )) } - } - } else { - let mut created_edges: isize = 0; - let between = Uniform::new(0, num_nodes); - while created_edges < num_edges { - let u = between.sample(&mut rng); - let v = between.sample(&mut rng); - let u_index = NodeIndex::new(u as usize); - let v_index = NodeIndex::new(v as usize); - // avoid self-loops and multi-graphs - if u != v && inner_graph.find_edge(u_index, v_index).is_none() { - inner_graph.add_edge(u_index, v_index, py.None()); - created_edges += 1; - } - } + }; + // Core function does not put index into node payload, so for backwards compat + // in the python interface, we do it here. 
+ let nodes: Vec = graph.node_indices().collect(); + for node in nodes.iter() { + graph[*node] = node.index().to_object(py); } - let graph = digraph::PyDiGraph { - graph: inner_graph, - cycle_state: algo::DfsSpace::default(), - check_cycle: false, + Ok(digraph::PyDiGraph { + graph, node_removed: false, - multigraph: true, + check_cycle: false, + cycle_state: algo::DfsSpace::default(), + multigraph: false, attrs: py.None(), - }; - Ok(graph) + }) } /// Return a :math:`G_{nm}` undirected graph, also known as an @@ -338,56 +244,33 @@ pub fn directed_gnm_random_graph( #[pyo3(text_signature = "(num_nodes, num_edges, /, seed=None)")] pub fn undirected_gnm_random_graph( py: Python, - num_nodes: isize, - num_edges: isize, + num_nodes: usize, + num_edges: usize, seed: Option, ) -> PyResult { - if num_nodes <= 0 { - return Err(PyValueError::new_err("num_nodes must be > 0")); - } - if num_edges < 0 { - return Err(PyValueError::new_err("num_edges must be >= 0")); - } - let mut rng: Pcg64 = match seed { - Some(seed) => Pcg64::seed_from_u64(seed), - None => Pcg64::from_entropy(), - }; - let mut inner_graph = StablePyGraph::::default(); - for x in 0..num_nodes { - inner_graph.add_node(x.to_object(py)); - } - // if number of edges to be created is >= max, - // avoid randomly missed trials and directly add edges between every node - if num_edges >= num_nodes * (num_nodes - 1) / 2 { - for u in 0..num_nodes { - for v in u + 1..num_nodes { - let u_index = NodeIndex::new(u as usize); - let v_index = NodeIndex::new(v as usize); - inner_graph.add_edge(u_index, v_index, py.None()); - } - } - } else { - let mut created_edges: isize = 0; - let between = Uniform::new(0, num_nodes); - while created_edges < num_edges { - let u = between.sample(&mut rng); - let v = between.sample(&mut rng); - let u_index = NodeIndex::new(u as usize); - let v_index = NodeIndex::new(v as usize); - // avoid self-loops and multi-graphs - if u != v && inner_graph.find_edge(u_index, v_index).is_none() { - inner_graph.add_edge(u_index, v_index, py.None()); - created_edges += 1; + let default_fn = || py.None(); + let mut graph: StablePyGraph = + match core_generators::gnm_random_graph(num_nodes, num_edges, seed, default_fn, default_fn) + { + Ok(graph) => graph, + Err(_) => { + return Err(PyValueError::new_err( + "num_nodes or num_edges invalid input", + )) } - } + }; + // Core function does not put index into node payload, so for backwards compat + // in the python interface, we do it here. 
+ let nodes: Vec = graph.node_indices().collect(); + for node in nodes.iter() { + graph[*node] = node.index().to_object(py); } - let graph = graph::PyGraph { - graph: inner_graph, + Ok(graph::PyGraph { + graph, node_removed: false, multigraph: true, attrs: py.None(), - }; - Ok(graph) + }) } #[inline] diff --git a/tests/retworkx_backwards_compat/test_random.py b/tests/retworkx_backwards_compat/test_random.py index 59ef481f6..67a60760c 100644 --- a/tests/retworkx_backwards_compat/test_random.py +++ b/tests/retworkx_backwards_compat/test_random.py @@ -34,7 +34,7 @@ def test_random_gnp_directed_complete_graph(self): def test_random_gnp_directed_invalid_num_nodes(self): with self.assertRaises(ValueError): - retworkx.directed_gnp_random_graph(-23, 0.5) + retworkx.directed_gnp_random_graph(0, 0.5) def test_random_gnp_directed_invalid_probability(self): with self.assertRaises(ValueError): @@ -57,7 +57,7 @@ def test_random_gnp_undirected_complete_graph(self): def test_random_gnp_undirected_invalid_num_nodes(self): with self.assertRaises(ValueError): - retworkx.undirected_gnp_random_graph(-23, 0.5) + retworkx.undirected_gnp_random_graph(0, 0.5) def test_random_gnp_undirected_invalid_probability(self): with self.assertRaises(ValueError): @@ -101,10 +101,10 @@ def test_random_gnm_directed_complete_graph(self): def test_random_gnm_directed_invalid_num_nodes(self): with self.assertRaises(ValueError): - retworkx.directed_gnm_random_graph(-23, 5) + retworkx.directed_gnm_random_graph(0, 5) def test_random_gnm_directed_invalid_num_edges(self): - with self.assertRaises(ValueError): + with self.assertRaises(OverflowError): retworkx.directed_gnm_random_graph(23, -5) def test_random_gnm_undirected(self): @@ -143,10 +143,10 @@ def test_random_gnm_undirected_complete_graph(self): def test_random_gnm_undirected_invalid_num_nodes(self): with self.assertRaises(ValueError): - retworkx.undirected_gnm_random_graph(-23, 5) + retworkx.undirected_gnm_random_graph(0, 5) - def test_random_gnm_undirected_invalid_probability(self): - with self.assertRaises(ValueError): + def test_random_gnm_undirected_invalid_num_edges(self): + with self.assertRaises(OverflowError): retworkx.undirected_gnm_random_graph(23, -5) diff --git a/tests/rustworkx_tests/test_random.py b/tests/rustworkx_tests/test_random.py index 1a76b4d33..1a96b27ef 100644 --- a/tests/rustworkx_tests/test_random.py +++ b/tests/rustworkx_tests/test_random.py @@ -34,12 +34,16 @@ def test_random_gnp_directed_complete_graph(self): def test_random_gnp_directed_invalid_num_nodes(self): with self.assertRaises(ValueError): - rustworkx.directed_gnp_random_graph(-23, 0.5) + rustworkx.directed_gnp_random_graph(0, 0.5) def test_random_gnp_directed_invalid_probability(self): with self.assertRaises(ValueError): rustworkx.directed_gnp_random_graph(23, 123.5) + def test_random_gnp_directed_payload(self): + graph = rustworkx.directed_gnp_random_graph(3, 0.5) + self.assertEqual(graph.nodes(), [0, 1, 2]) + def test_random_gnp_undirected(self): graph = rustworkx.undirected_gnp_random_graph(20, 0.5, seed=10) self.assertEqual(len(graph), 20) @@ -57,12 +61,16 @@ def test_random_gnp_undirected_complete_graph(self): def test_random_gnp_undirected_invalid_num_nodes(self): with self.assertRaises(ValueError): - rustworkx.undirected_gnp_random_graph(-23, 0.5) + rustworkx.undirected_gnp_random_graph(0, 0.5) def test_random_gnp_undirected_invalid_probability(self): with self.assertRaises(ValueError): rustworkx.undirected_gnp_random_graph(23, 123.5) + def test_random_gnp_undirected_payload(self): + 
graph = rustworkx.undirected_gnp_random_graph(3, 0.5) + self.assertEqual(graph.nodes(), [0, 1, 2]) + class TestGNMRandomGraph(unittest.TestCase): def test_random_gnm_directed(self): @@ -101,12 +109,16 @@ def test_random_gnm_directed_complete_graph(self): def test_random_gnm_directed_invalid_num_nodes(self): with self.assertRaises(ValueError): - rustworkx.directed_gnm_random_graph(-23, 5) + rustworkx.directed_gnm_random_graph(0, 0) def test_random_gnm_directed_invalid_num_edges(self): - with self.assertRaises(ValueError): + with self.assertRaises(OverflowError): rustworkx.directed_gnm_random_graph(23, -5) + def test_random_gnm_directed_payload(self): + graph = rustworkx.directed_gnm_random_graph(3, 3) + self.assertEqual(graph.nodes(), [0, 1, 2]) + def test_random_gnm_undirected(self): graph = rustworkx.undirected_gnm_random_graph(20, 100) self.assertEqual(len(graph), 20) @@ -143,12 +155,16 @@ def test_random_gnm_undirected_complete_graph(self): def test_random_gnm_undirected_invalid_num_nodes(self): with self.assertRaises(ValueError): - rustworkx.undirected_gnm_random_graph(-23, 5) + rustworkx.undirected_gnm_random_graph(0, 5) - def test_random_gnm_undirected_invalid_probability(self): - with self.assertRaises(ValueError): + def test_random_gnm_undirected_invalid_num_edges(self): + with self.assertRaises(OverflowError): rustworkx.undirected_gnm_random_graph(23, -5) + def test_random_gnm_undirected_payload(self): + graph = rustworkx.undirected_gnm_random_graph(3, 3) + self.assertEqual(graph.nodes(), [0, 1, 2]) + class TestGeometricRandomGraph(unittest.TestCase): def test_random_geometric_empty(self): From 32ea6c76f743096ce43fb2e12eb32af11cb0476b Mon Sep 17 00:00:00 2001 From: Edwin Navarro Date: Wed, 17 May 2023 14:09:47 -0700 Subject: [PATCH 15/37] Add large random test (#878) --- rustworkx-core/src/token_swapper.rs | 76 +++++++++++++++++++---------- 1 file changed, 51 insertions(+), 25 deletions(-) diff --git a/rustworkx-core/src/token_swapper.rs b/rustworkx-core/src/token_swapper.rs index 469236acc..1aa84848d 100644 --- a/rustworkx-core/src/token_swapper.rs +++ b/rustworkx-core/src/token_swapper.rs @@ -580,29 +580,55 @@ mod test_token_swapper { assert_eq!(5, swaps.len()); assert_eq!(expected, new_map); } -} -// TODO: Port this test when rustworkx-core adds random graphs - -// def test_large_partial_random(self) -> None: -// """Test a random (partial) mapping on a large randomly generated graph""" -// size = 100 -// # Note that graph may have "gaps" in the node counts, i.e. the numbering is noncontiguous. -// graph = rx.undirected_gnm_random_graph(size, size**2 // 10) -// for i in graph.node_indexes(): -// try: -// graph.remove_edge(i, i) # Remove self-loops. -// except rx.NoEdgeBetweenNodes: -// continue -// # Make sure the graph is connected by adding C_n -// graph.add_edges_from_no_data([(i, i + 1) for i in range(len(graph) - 1)]) -// swapper = ApproximateTokenSwapper(graph) # type: ApproximateTokenSwapper[int] - -// # Generate a randomized permutation. -// rand_perm = random.permutation(graph.nodes()) -// permutation = dict(zip(graph.nodes(), rand_perm)) -// mapping = dict(itertools.islice(permutation.items(), 0, size, 2)) # Drop every 2nd element. 
- -// out = list(swapper.map(mapping, trials=40)) -// util.swap_permutation([out], mapping, allow_missing_keys=True) -// self.assertEqual({i: i for i in mapping.values()}, mapping) + #[test] + fn test_large_partial_random() { + // Test a random (partial) mapping on a large randomly generated graph + use crate::generators::gnm_random_graph; + use rand::prelude::*; + use rand_pcg::Pcg64; + use std::iter::zip; + + let mut rng: Pcg64 = Pcg64::seed_from_u64(4); + + // Note that graph may have "gaps" in the node counts, i.e. the numbering is noncontiguous. + let size = 100; + let mut g: petgraph::stable_graph::StableGraph<(), ()> = + gnm_random_graph(size, size.pow(2) / 10, Some(4), || (), || ()).unwrap(); + + // Remove self-loops + let nodes: Vec<_> = g.node_indices().collect(); + for node in nodes { + let edge = g.find_edge(node, node); + if edge.is_some() { + g.remove_edge(edge.unwrap()); + } + } + // Make sure the graph is connected by adding C_n + for i in 0..(g.node_count() - 1) { + g.add_edge(NodeIndex::new(i), NodeIndex::new(i + 1), ()); + } + + // Get node indices and randomly shuffle + let mut mapped_nodes: Vec = g.node_indices().map(|node| node.index()).collect(); + let nodes = mapped_nodes.clone(); + mapped_nodes.shuffle(&mut rng); + + // Zip nodes and shuffled nodes and remove every other one + let mut mapping: Vec<(usize, usize)> = zip(nodes, mapped_nodes).collect(); + mapping.retain(|(a, _)| a % 2 == 0); + + // Convert mapping to HashMap of NodeIndex's + let mapping: HashMap = mapping + .into_iter() + .map(|(a, b)| (NodeIndex::new(a), NodeIndex::new(b))) + .collect(); + let mut new_map = mapping.clone(); + let expected: HashMap = + mapping.values().map(|val| (*val, *val)).collect(); + + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + do_swap(&mut new_map, &swaps); + assert_eq!(expected, new_map) + } +} From 171aa31ff0b0202ec4ed4c2ac51d7dbf364d5d4e Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 17 May 2023 19:39:54 -0400 Subject: [PATCH 16/37] Add env variable to Python builds to force debug builds (#876) * Add env variable to Python builds to force debug builds This commit updates the setup.py for the python package to add a new env variable `RUST_DEBUG` to force the package to be built in debug mode. By default pip install will build the rust code in release mode, which is the sane default you want for publishing or installing a package. But for local development you typically want to build normally in debug mode. This increases the build speed and also adds additional runtime checking to validate the code is fully working. The tradeoff with this though is the runtime is very poor because the compiler doesn't do any optimization. As part of this commit the tox configuration is updated to default to debug builds. For unit tests the execution time is unchanged, because while it compiles faster that is offset by the slower execution of the tests. However, in general I think it's better to run tests in debug mode by default because it will do runtime validation (e.g. bounds checks overflow detection, etc) which is good to catch in testing. 
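For local development the intended usage is roughly the following (a sketch, not
taken from this patch; note that the setup.py hunk below reads the
RUSTWORKX_DEBUG variable while the tox.ini hunk sets RUST_DEBUG, so the
effective variable name is whichever one setup.py actually checks):

    RUSTWORKX_DEBUG=1 pip install .
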
* Apply suggestions from code review Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> --------- Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> --- setup.py | 7 ++++++- tox.ini | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 02d3ff6a5..053825cc1 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,11 @@ from setuptools_rust import Binding, RustExtension +# If RUST_DEBUG is set, force compiling in debug mode. Else, use the default behavior of whether +# it's an editable installation. +rustworkx_debug = True if os.getenv("RUSTWORKX_DEBUG") == "1" else None + + def readme(): with open('README.md') as f: return f.read() @@ -25,7 +30,7 @@ def readme(): PKG_PACKAGES = ["rustworkx", "rustworkx.visualization"] PKG_INSTALL_REQUIRES = ['numpy>=1.16.0'] RUST_EXTENSIONS = [RustExtension("rustworkx.rustworkx", "Cargo.toml", - binding=Binding.PyO3)] + binding=Binding.PyO3, debug=rustworkx_debug)] retworkx_readme_compat = """# retworkx diff --git a/tox.ini b/tox.ini index cc246d3f1..7f13c3084 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ setenv = LANGUAGE=en_US LC_ALL=en_US.utf-8 ARGS="-V" + RUST_DEBUG=1 deps = setuptools-rust fixtures @@ -22,6 +23,7 @@ extras = passenv = RETWORKX_TEST_PRESERVE_IMAGES RUSTWORKX_PKG_NAME + RUSTWORKX_DEBUG changedir = {toxinidir}/tests commands = stestr run {posargs} @@ -43,12 +45,14 @@ commands = basepython = python3 setenv = {[testenv]setenv} + RUSTWORKX_DEBUG=1 deps = -r {toxinidir}/docs/source/requirements.txt passenv = {[testenv]passenv} RETWORKX_DEV_DOCS RETWORKX_LEGACY_DOCS + RUST_DEBUG changedir = {toxinidir}/docs commands = python -m ipykernel install --user From 09d5707e8657057e1dcde758f912bc89ea842cf3 Mon Sep 17 00:00:00 2001 From: Eric Arellano <14852634+Eric-Arellano@users.noreply.github.com> Date: Tue, 23 May 2023 20:09:16 -0600 Subject: [PATCH 17/37] Upgrade to qiskit_sphinx_theme 1.12 (#880) * Upgrade to qiskit_sphinx_theme 1.12 * Fix static path option --- .github/workflows/docs_dev.yml | 1 + docs/source/_static/overrides.css | 3 + docs/source/_templates/layout.html | 4 -- docs/source/_templates/sidebar.html | 99 ----------------------------- docs/source/conf.py | 4 +- docs/source/requirements.txt | 4 +- tox.ini | 6 ++ 7 files changed, 15 insertions(+), 106 deletions(-) create mode 100644 docs/source/_static/overrides.css delete mode 100644 docs/source/_templates/sidebar.html diff --git a/.github/workflows/docs_dev.yml b/.github/workflows/docs_dev.yml index c55b0aab8..f0a4b1fca 100644 --- a/.github/workflows/docs_dev.yml +++ b/.github/workflows/docs_dev.yml @@ -1,5 +1,6 @@ name: Docs Publish on: + workflow_dispatch: push: branches: [ main ] diff --git a/docs/source/_static/overrides.css b/docs/source/_static/overrides.css new file mode 100644 index 000000000..c6a64ff83 --- /dev/null +++ b/docs/source/_static/overrides.css @@ -0,0 +1,3 @@ +:root { + --header-height: 0rem; +} diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 515164542..4b7c22a80 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -156,13 +156,9 @@ {{ js_tag(scriptfile) }} {%- endfor %} - - - - -{% endif %} - - - \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 4e8977613..ddc58fb81 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ project = 'rustworkx' copyright = '2021, rustworkx Contributors' - +docs_url_prefix = "ecosystem/rustworkx" 
# The short X.Y version. version = '0.13.0' @@ -42,7 +42,9 @@ 'sphinx_reredirects', 'qiskit_sphinx_theme', ] +html_static_path = ["_static"] templates_path = ['_templates'] +extra_css_files = ["overrides.css"] pygments_style = 'colorful' diff --git a/docs/source/requirements.txt b/docs/source/requirements.txt index 57816784a..32b3c6359 100644 --- a/docs/source/requirements.txt +++ b/docs/source/requirements.txt @@ -1,10 +1,10 @@ m2r2 -sphinx>=3.0.0 +sphinx>=5.0 jupyter-sphinx pydot pillow>=4.2.1 reno>=3.4.0 -qiskit-sphinx-theme~=1.11.1 +qiskit-sphinx-theme~=1.12.0 matplotlib>=3.4 sphinx-reredirects sphinxemoji diff --git a/tox.ini b/tox.ini index 7f13c3084..69cbd098f 100644 --- a/tox.ini +++ b/tox.ini @@ -59,6 +59,12 @@ commands = jupyter kernelspec list sphinx-build -W -d {toxinidir}/docs/build/.doctrees -b html source build/html {posargs} +[testenv:docs-clean] +skip_install = true +deps = +allowlist_externals = rm +commands = rm -rf {toxinidir}/docs/build {toxinidir}/docs/apiref + [testenv:black] basepython = python3 deps = From afc3627da8f8fb9729bb9887a5c34e5ce526d37a Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Fri, 26 May 2023 08:28:35 -0400 Subject: [PATCH 18/37] Add PageRank (#788) Related to #315 Adds an implementation of the PageRank algorithm using sparse matrices. It uses the sprs crate combined with ndarray to implement a Power Method approach of finding the PageRank. Also, we test this implementation against NetworkX's implementation of the PageRank. We accept all the arguments that NetworkX accepts: tolerance, max_iter, personalization, dangling, etc. * Add the sketch of PageRank * More progress towards pagerank * Use CentralityMapping and FailedToConverge * Finalize PageRank * First test does not run * Remove unwanted triplet * Fix clippy warning * Handle personalization correctly * Add more tests * Cargo fmt * Add scipy to test requirements * Skip SciPy tests in case architecture does not have it * Ignore flake8 errors that do not help * Flake8 * Handle dangling weights * Add more tests * Add nstart argument * Cargo Clippy * Documentation * Fix typo in URL * Update releasenotes/notes/add-pagerank-bef0de7d46026071.yaml Co-authored-by: Matthew Treinish * Tweak pyfunction signature * Add scipy to aarch64 test requirements * Address comments from code review * Clippy is always right --------- Co-authored-by: Matthew Treinish --- .github/workflows/main.yml | 2 +- .github/workflows/wheels.yml | 4 +- Cargo.lock | 86 ++++++- Cargo.toml | 5 + docs/source/api.rst | 10 + .../notes/add-pagerank-bef0de7d46026071.yaml | 29 +++ src/lib.rs | 4 + src/link_analysis.rs | 227 +++++++++++++++++ .../rustworkx_tests/digraph/test_pagerank.py | 236 ++++++++++++++++++ tox.ini | 1 + 10 files changed, 597 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/add-pagerank-bef0de7d46026071.yaml create mode 100644 src/link_analysis.rs create mode 100644 tests/rustworkx_tests/digraph/test_pagerank.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 56db02926..2520c934f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -168,7 +168,7 @@ jobs: - name: Download grcov run: curl -L https://github.com/mozilla/grcov/releases/download/v0.8.7/grcov-x86_64-unknown-linux-gnu.tar.bz2 | tar jxf - - name: Install deps - run: pip install -U setuptools-rust networkx testtools fixtures + run: pip install -U setuptools-rust networkx scipy testtools fixtures - name: Build retworkx run: python setup.py 
develop env: diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 225c9ceff..b94893ecd 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -67,7 +67,7 @@ jobs: CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest CIBW_SKIP: cp36-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust - CIBW_TEST_REQUIRES: networkx testtools fixtures + CIBW_TEST_REQUIRES: networkx scipy testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests - uses: actions/upload-artifact@v3 with: @@ -109,7 +109,7 @@ jobs: CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest CIBW_SKIP: cp36-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust - CIBW_TEST_REQUIRES: networkx testtools fixtures + CIBW_TEST_REQUIRES: networkx scipy testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests CIBW_ARCHS_LINUX: aarch64 - uses: actions/upload-artifact@v3 diff --git a/Cargo.lock b/Cargo.lock index a3c081365..e0ed61b78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,6 +25,26 @@ dependencies = [ "version_check", ] +[[package]] +name = "alga" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f823d037a7ec6ea2197046bafd4ae150e6bc36f9ca347404f46a46823fa84f2" +dependencies = [ + "approx", + "num-complex 0.2.4", + "num-traits", +] + +[[package]] +name = "approx" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0e60b75072ecd4168020818c0107f2857bb6c4e64252d8d3983f6263b40a5c3" +dependencies = [ + "num-traits", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -166,6 +186,12 @@ version = "0.2.137" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" +[[package]] +name = "libm" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" + [[package]] name = "lock_api" version = "0.4.9" @@ -216,13 +242,37 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" dependencies = [ "matrixmultiply", - "num-complex", + "num-complex 0.4.3", "num-integer", "num-traits", "rawpointer", "rayon", ] +[[package]] +name = "ndarray-stats" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af5a8477ac96877b5bd1fd67e0c28736c12943aba24eda92b127e036b0c8f400" +dependencies = [ + "indexmap", + "itertools", + "ndarray", + "noisy_float", + "num-integer", + "num-traits", + "rand", +] + +[[package]] +name = "noisy_float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af" +dependencies = [ + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.3" @@ -234,6 +284,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +dependencies = [ + "autocfg", + "num-traits", +] + [[package]] name = "num-complex" version = "0.4.3" @@ -260,6 +320,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -280,7 +341,7 @@ checksum = "96b0fee4571867d318651c24f4a570c3f18408cf95f16ccb576b3ce85496a46e" dependencies = [ "libc", "ndarray", - "num-complex", + "num-complex 0.4.3", "num-integer", "num-traits", "pyo3", @@ -364,7 +425,7 @@ dependencies = [ "libc", "memoffset 0.8.0", "num-bigint", - "num-complex", + "num-complex 0.4.3", "parking_lot", "pyo3-build-config", "pyo3-ffi", @@ -535,8 +596,9 @@ dependencies = [ "hashbrown", "indexmap", "ndarray", + "ndarray-stats", "num-bigint", - "num-complex", + "num-complex 0.4.3", "num-traits", "numpy", "petgraph", @@ -548,6 +610,7 @@ dependencies = [ "rustworkx-core", "serde", "serde_json", + "sprs", ] [[package]] @@ -616,6 +679,21 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +[[package]] +name = "sprs" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea71e48b3eab4c4b153e8e35dcaeac132720809ef68359097b8cb54a18edd70" +dependencies = [ + "alga", + "ndarray", + "num-complex 0.4.3", + "num-traits", + "num_cpus", + "rayon", + "smallvec", +] + [[package]] name = "syn" version = "1.0.104" diff --git a/Cargo.toml b/Cargo.toml index bc95f687e..0a7fafa3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ rayon = "1.6" num-traits = "0.2" num-bigint = "0.4" num-complex = "0.4" +ndarray-stats = "0.5.1" quick-xml = "0.28" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -50,6 +51,10 @@ features = ["rayon"] version = "1.9" features = ["rayon"] +[dependencies.sprs] +version = "^0.11" +features = ["multi_thread"] + [profile.release] lto = 'fat' codegen-units = 1 diff --git a/docs/source/api.rst b/docs/source/api.rst index 44a7dab7a..73383fd90 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -58,6 +58,16 @@ Centrality rustworkx.eigenvector_centrality rustworkx.closeness_centrality +.. _link-analysis: + +Link Analysis +-------------- + +.. autosummary:: + :toctree: apiref + + rustworkx.pagerank + .. _traversal: Traversal diff --git a/releasenotes/notes/add-pagerank-bef0de7d46026071.yaml b/releasenotes/notes/add-pagerank-bef0de7d46026071.yaml new file mode 100644 index 000000000..1fe4fa284 --- /dev/null +++ b/releasenotes/notes/add-pagerank-bef0de7d46026071.yaml @@ -0,0 +1,29 @@ +--- +features: + - | + Added a new function, :func:`~.pagerank` which is used to + compute the PageRank score for all nodes in a given directed graph. + For example: + + .. 
jupyter-execute:: + + import rustworkx as rx + from rustworkx.visualization import mpl_draw + + graph = rx.generators.directed_hexagonal_lattice_graph(2, 2) + ranks = rx.pagerank(graph) + + # Generate a color list + colors = [] + for node in graph.node_indices(): + pagerank_score = ranks[node] + graph[node] = pagerank_score + colors.append(pagerank_score) + mpl_draw( + graph, + with_labels=True, + node_color=colors, + node_size=650, + labels=lambda x: "{0:.2f}".format(x) + ) + diff --git a/src/lib.rs b/src/lib.rs index 7f941e2e6..dedec6c6e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,7 @@ mod isomorphism; mod iterators; mod json; mod layout; +mod link_analysis; mod matching; mod planar; mod random_graph; @@ -47,6 +48,8 @@ use graphml::*; use isomorphism::*; use json::*; use layout::*; +use link_analysis::*; + use matching::*; use planar::*; use random_graph::*; @@ -485,6 +488,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(read_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; + m.add_wrapped(wrap_pyfunction!(pagerank))?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/link_analysis.rs b/src/link_analysis.rs new file mode 100644 index 000000000..d38cf0ce3 --- /dev/null +++ b/src/link_analysis.rs @@ -0,0 +1,227 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// PageRank has many possible personalizations, so we accept them all +#![allow(clippy::too_many_arguments)] + +use pyo3::prelude::*; +use pyo3::Python; + +use crate::digraph::PyDiGraph; +use crate::iterators::CentralityMapping; +use crate::{weight_callable, FailedToConverge}; + +use hashbrown::HashMap; +use ndarray::prelude::*; +use ndarray_stats::DeviationExt; +use petgraph::prelude::*; +use petgraph::visit::IntoEdgeReferences; +use petgraph::visit::NodeIndexable; +use rustworkx_core::dictmap::*; +use sprs::{CsMat, TriMat}; + +/// Computes the PageRank of the nodes in a :class:`~PyDiGraph`. +/// +/// For details on the PageRank, refer to: +/// +/// L. Page, S. Brin, R. Motwani, and T. Winograd. “The PageRank Citation Ranking: Bringing order to the Web”. +/// Stanford Digital Library Technologies Project, (1998). +/// +/// +/// This function uses a power iteration method to compute the PageRank +/// and convergence is not guaranteed. The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// The implementation of this algorithm tries to match NetworkX's +/// `pagerank() `__ +/// implementation. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the PageRank. +/// +/// :param PyDiGraph graph: The graph object to run the algorithm on +/// :param float alpha: Damping parameter for PageRank, default=0.85. 
+/// :param weight_fn: An optional input callable that will be passed the edge's +/// payload object and is expected to return a `float` weight for that edge. +/// If this is not specified ``default_weight`` will be used as the weight +/// for every edge in ``graph`` +/// :param dict nstart: Optional starting value of PageRank iteration for each node. +/// :param dict personalization: An optional dictionary representing the personalization +/// vector for a subset of nodes. At least one personalization entry must be non-zero. +/// If not specified, a nodes personalization value will be zero. By default, +/// a uniform distribution is used. +/// :param float tol: The error tolerance used when checking for convergence in the +/// power method. If this is not specified default value of 1e-6 is used. +/// :param int max_iter: The maximum number of iterations in the power method. If +/// not specified a default value of 100 is used. +/// :param dict dangling: An optional dictionary for the outedges to be assigned to any "dangling" nodes, +/// i.e., nodes without any outedges. The dict key is the node the outedge points to and the dict +/// value is the weight of that outedge. By default, dangling nodes are given outedges according to +/// the personalization vector (uniform if not specified). This must be selected to result in an irreducible +/// transition matrix. It may be common to have the dangling dict to be the same as the personalization dict. +/// +/// :returns: a read-only dict-like object whose keys are the node indices and values are the +/// PageRank score for that node. +/// :rtype: CentralityMapping +#[pyfunction( + signature = ( + graph, + alpha=0.85, + weight_fn=None, + nstart=None, + personalization=None, + tol=1e-6, + max_iter=100, + dangling=None, + ) +)] +#[pyo3( + text_signature = "(graph, /, alpha=0.85, weight_fn=None, nstart=None, personalization=None, tol=1.0e-6, max_iter=100)" +)] +pub fn pagerank( + py: Python, + graph: &PyDiGraph, + alpha: f64, + weight_fn: Option, + nstart: Option>, + personalization: Option>, + tol: f64, + max_iter: usize, + dangling: Option>, +) -> PyResult { + // we use the node bound to make the code work if nodes were removed + let n = graph.graph.node_count(); + let mat_size = graph.graph.node_bound(); + let node_indices: Vec = graph.graph.node_indices().map(|x| x.index()).collect(); + + // Handle empty case + if n == 0 { + return Ok(CentralityMapping { + centralities: DictMap::new(), + }); + } + + // Grab the graph weights from Python to Rust + let mut in_weights: HashMap<(usize, usize), f64> = + HashMap::with_capacity(graph.graph.edge_count()); + let mut out_weights: Vec = vec![0.0; mat_size]; + let default_weight: f64 = 1.0; + + for edge in graph.graph.edge_references() { + let i = NodeIndexable::to_index(&graph.graph, edge.source()); + let j = NodeIndexable::to_index(&graph.graph, edge.target()); + let weight = edge.weight().clone(); + + let edge_weight = weight_callable(py, &weight_fn, &weight, default_weight)?; + out_weights[i] += edge_weight; + *in_weights.entry((i, j)).or_insert(0.0) += edge_weight; + } + + // Create sparse Google Matrix that describes the Markov Chain process + let mut a = TriMat::new((mat_size, mat_size)); + for ((i, j), weight) in in_weights.into_iter() { + a.add_triplet(j, i, weight / out_weights[i]); + } + let a: CsMat<_> = a.to_csr(); + + // Vector with probabilities for the Markov Chain process + let mut popularity = Array1::::zeros(mat_size); + let default_pop = (n as f64).recip(); + + // Handle custom start + 
if let Some(nstart) = nstart { + for i in &node_indices { + popularity[*i] = *nstart.get(i).unwrap_or(&0.0); + } + let pop_sum = popularity.sum(); + popularity /= pop_sum; + } else { + for i in &node_indices { + popularity[*i] = default_pop; + } + } + + // Handle personalization + let personalized_array: Array1 = match personalization { + Some(personalization) => { + let mut personalized_array = Array1::::zeros(mat_size); + for i in &node_indices { + personalized_array[*i] = *personalization.get(i).unwrap_or(&0.0); + } + let p_sum = personalized_array.sum(); + personalized_array /= p_sum; + personalized_array + } + None => { + let mut personalized_array = Array1::::zeros(mat_size); + for i in &node_indices { + personalized_array[*i] = default_pop; + } + personalized_array + } + }; + let damping = (1.0 - alpha) * &personalized_array; + + // Handle dangling nodes i.e. nodes that point nowhere + let is_dangling = (0..mat_size) + .map(|i| out_weights[i] == 0.0) + .collect::>(); + let dangling_weights: Array1 = match dangling { + Some(dangling) => { + let mut dangling_weights = Array1::::zeros(mat_size); + for i in &node_indices { + dangling_weights[*i] = *dangling.get(i).unwrap_or(&0.0); + } + let d_sum = dangling_weights.sum(); + dangling_weights /= d_sum; + dangling_weights + } + None => personalized_array, + }; + + // Power Method iteration for the Google Matrix + let mut has_converged = false; + for _ in 0..max_iter { + let dangling_sum: f64 = is_dangling + .iter() + .zip(popularity.iter()) + .map(|(cond, pop)| if *cond { *pop } else { 0.0 }) + .sum(); + let new_popularity = + alpha * ((&a * &popularity) + (dangling_sum * &dangling_weights)) + &damping; + let norm: f64 = new_popularity.l1_dist(&popularity).unwrap(); + if norm < (n as f64) * tol { + has_converged = true; + break; + } else { + popularity = new_popularity; + } + } + + // Convert to custom return type + if !has_converged { + return Err(FailedToConverge::new_err(format!( + "Function failed to converge on a solution in {} iterations", + max_iter + ))); + } + + let out_map: DictMap = graph + .graph + .node_indices() + .map(|x| (x.index(), popularity[x.index()])) + .collect(); + + Ok(CentralityMapping { + centralities: out_map, + }) +} diff --git a/tests/rustworkx_tests/digraph/test_pagerank.py b/tests/rustworkx_tests/digraph/test_pagerank.py new file mode 100644 index 000000000..6cc3b9274 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_pagerank.py @@ -0,0 +1,236 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +# These tests are adapated from the networkx test cases: +# https://github.com/networkx/networkx/blob/cea310f9066efc0d5ff76f63d33dbc3eefe61f6b/networkx/algorithms/link_analysis/tests/test_pagerank.py + +import unittest + +import rustworkx +import networkx as nx + + +class TestPageRank(unittest.TestCase): + def setUp(self) -> None: + try: + # required for networkx.pagerank to work + import scipy + + self.assertIsNotNone(scipy.__version__) + except ModuleNotFoundError: + self.skipTest("SciPy is not installed, skipping PageRank tests") + + def test_with_dangling_node(self): + edges = [ + (0, 1), + (0, 2), + (2, 0), + (2, 1), + (2, 4), + (3, 4), + (3, 5), + (4, 3), + (4, 5), + (5, 4), + ] # node 1 is dangling because it does not point to anyone + + rx_graph = rustworkx.PyDiGraph() + nx_graph = nx.DiGraph() + + rx_graph.extend_from_edge_list(edges) + nx_graph.add_edges_from(edges) + + alpha = 0.9 + tol = 1.0e-8 + + rx_ranks = rustworkx.pagerank(rx_graph, alpha=alpha, tol=tol) + nx_ranks = nx.pagerank(nx_graph, alpha=alpha, tol=tol) + + for v in rx_graph.node_indices(): + self.assertAlmostEqual(rx_ranks[v], nx_ranks[v], delta=1.0e-4) + + def test_with_dangling_node_and_argument(self): + edges = [ + (0, 1), + (0, 2), + (2, 0), + (2, 1), + (2, 4), + (3, 4), + (3, 5), + (4, 3), + (4, 5), + (5, 4), + ] # node 1 is dangling because it does not point to anyone + + rx_graph = rustworkx.PyDiGraph() + nx_graph = nx.DiGraph() + + rx_graph.extend_from_edge_list(edges) + nx_graph.add_edges_from(edges) + + dangling = {0: 0, 1: 1, 2: 2, 3: 0, 5: 0} + + alpha = 0.85 + tol = 1.0e-8 + + rx_ranks = rustworkx.pagerank(rx_graph, alpha=alpha, tol=tol, dangling=dangling) + nx_ranks = nx.pagerank(nx_graph, alpha=alpha, tol=tol, dangling=dangling) + + for v in rx_graph.node_indices(): + self.assertAlmostEqual(rx_ranks[v], nx_ranks[v], delta=1.0e-4) + + def test_empty(self): + graph = rustworkx.PyDiGraph() + ranks = rustworkx.pagerank(graph) + self.assertEqual({}, ranks) + + def test_one_node(self): + graph = rustworkx.PyDiGraph() + graph.add_node(0) + ranks = rustworkx.pagerank(graph) + self.assertEqual({0: 1}, ranks) + + def test_cycle_graph(self): + graph = rustworkx.generators.directed_cycle_graph(100) + ranks = rustworkx.pagerank(graph) + + for v in graph.node_indices(): + self.assertAlmostEqual(ranks[v], 1 / 100.0, delta=1.0e-4) + + def test_with_removed_node(self): + graph = rustworkx.PyDiGraph() + + edges = [ + (0, 1), + (1, 2), + (2, 3), + (3, 0), + (4, 0), + (4, 1), + (4, 2), + (0, 4), + ] + graph.extend_from_edge_list(edges) + graph.remove_node(3) + + ranks = rustworkx.pagerank(graph) + + expected_ranks = { + 0: 0.17401467654615052, + 1: 0.2479710438690554, + 2: 0.3847906219106203, + 4: 0.19322365767417365, + } + + for v in graph.node_indices(): + self.assertAlmostEqual(ranks[v], expected_ranks[v], delta=1.0e-4) + + def test_pagerank_with_nstart(self): + rx_graph = rustworkx.generators.directed_complete_graph(4) + nstart = {0: 0.5, 1: 0.5, 2: 0, 3: 0} + alpha = 0.85 + rx_ranks = rustworkx.pagerank(rx_graph, alpha=alpha, nstart=nstart) + nx_graph = nx.DiGraph(list(rx_graph.edge_list())) + nx_ranks = nx.pagerank(nx_graph, alpha=alpha, nstart=nstart) + + for v in rx_graph.node_indices(): + self.assertAlmostEqual(rx_ranks[v], nx_ranks[v], delta=1.0e-4) + + def test_pagerank_with_personalize(self): + rx_graph = rustworkx.generators.directed_complete_graph(4) + personalize = {0: 0, 1: 0, 2: 0, 3: 1} + alpha = 0.85 + rx_ranks = rustworkx.pagerank(rx_graph, alpha=alpha, personalization=personalize) + nx_graph = 
nx.DiGraph(list(rx_graph.edge_list())) + nx_ranks = nx.pagerank(nx_graph, alpha=alpha, personalization=personalize) + + for v in rx_graph.node_indices(): + self.assertAlmostEqual(rx_ranks[v], nx_ranks[v], delta=1.0e-4) + + def test_pagerank_with_personalize_missing(self): + rx_graph = rustworkx.generators.directed_complete_graph(4) + personalize = {3: 1} + alpha = 0.85 + rx_ranks = rustworkx.pagerank(rx_graph, alpha=alpha, personalization=personalize) + nx_graph = nx.DiGraph(list(rx_graph.edge_list())) + nx_ranks = nx.pagerank(nx_graph, alpha=alpha, personalization=personalize) + + for v in rx_graph.node_indices(): + self.assertAlmostEqual(rx_ranks[v], nx_ranks[v], delta=1.0e-4) + + def test_multi_digraph(self): + rx_graph = rustworkx.PyDiGraph() + rx_graph.extend_from_edge_list( + [ + (0, 1), + (1, 0), + (0, 1), + (1, 0), + (0, 1), + (1, 0), + (1, 2), + (2, 1), + (1, 2), + (2, 1), + (2, 3), + (3, 2), + (2, 3), + (3, 2), + ] + ) + nx_graph = nx.MultiDiGraph(list(rx_graph.edge_list())) + + alpha = 0.9 + rx_ranks = rustworkx.pagerank(rx_graph, alpha=alpha) + nx_ranks = nx.pagerank(nx_graph, alpha=alpha) + + for v in rx_graph.node_indices(): + self.assertAlmostEqual(rx_ranks[v], nx_ranks[v], delta=1.0e-4) + + def test_no_convergence(self): + graph = rustworkx.generators.directed_complete_graph(4) + with self.assertRaises(rustworkx.FailedToConverge): + rustworkx.pagerank(graph, max_iter=0) + + def test_multi_digraph_versus_weighted(self): + multi_graph = rustworkx.PyDiGraph() + multi_graph.extend_from_edge_list( + [ + (0, 1), + (1, 0), + (0, 1), + (1, 0), + (0, 1), + (1, 0), + (1, 2), + (2, 1), + (1, 2), + (2, 1), + (2, 3), + (3, 2), + (2, 3), + (3, 2), + ] + ) + + weighted_graph = rustworkx.PyDiGraph() + weighted_graph.extend_from_weighted_edge_list( + [(0, 1, 3), (1, 0, 3), (1, 2, 2), (2, 1, 2), (2, 3, 2), (3, 2, 2)] + ) + + alpha = 0.85 + ranks_multi = rustworkx.pagerank(multi_graph, alpha=alpha, weight_fn=lambda _: 1.0) + ranks_weight = rustworkx.pagerank(weighted_graph, alpha=alpha, weight_fn=float) + + for v in multi_graph.node_indices(): + self.assertAlmostEqual(ranks_multi[v], ranks_weight[v], delta=1.0e-4) diff --git a/tox.ini b/tox.ini index 69cbd098f..aeab7f94c 100644 --- a/tox.ini +++ b/tox.ini @@ -16,6 +16,7 @@ deps = fixtures testtools>=2.5.0 networkx>=2.5 + scipy>=1.7 stestr extras = mpl From e1ca8eeec05b087a1d88f3bbb0d649a2f11f05d7 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Sun, 28 May 2023 13:46:53 -0400 Subject: [PATCH 19/37] Add longest_simple_path() function to rustworkx (#731) * Add longest_simple_path() function to rustworkx This commit adds a new function longest_simple_path() which is used to find the longest simple path in the graph. It internally is running basically the same algorithm as all_pairs_all_simple_paths but instead of building a giant nested mapping of all the simple paths found in the graph it instead filters the return to just the longest path found. This is useful if you're quickly trying to find the longest path in the graph as it avoids the overhead of building the return object and then iterating over it again. * Fix doc example of the alternative python example This commit fixes the code example for the python space example of how to compute the longest path in a graph (at the cose of slower runtime and a lot more memory). 
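As a usage sketch of the new function described above (the graph and the
asserted result here are illustrative only and are not taken from this patch):

    import rustworkx as rx

    graph = rx.PyGraph()
    graph.extend_from_edge_list([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3)])
    # Returns a NodeIndices sequence; which longest path is returned is not
    # guaranteed when several paths share the maximum length.
    path = rx.longest_simple_path(graph)
    assert len(path) == 4  # this small graph has a Hamiltonian path
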
Co-authored-by: Alexander Ivrii * Deduplicate function implementations between graph types * Simplify logic for selection of longest path * Further simplify iterator * Better handle empty path lists * Remove clone() usage * Add dedidcated longest path function to rustworkx-core * Update add-longest-simple-path-afdc4538c49bc38f.yaml --------- Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Co-authored-by: Alexander Ivrii --- docs/source/api.rst | 3 + ...-longest-simple-path-afdc4538c49bc38f.yaml | 33 +++ .../src/connectivity/all_simple_paths.rs | 226 +++++++++++++++++- rustworkx-core/src/connectivity/mod.rs | 4 +- rustworkx/__init__.py | 43 ++++ src/connectivity/mod.rs | 117 ++++++++- src/lib.rs | 2 + .../digraph/test_all_simple_paths.py | 39 +++ .../graph/test_all_simple_paths.py | 29 +++ 9 files changed, 486 insertions(+), 10 deletions(-) create mode 100644 releasenotes/notes/add-longest-simple-path-afdc4538c49bc38f.yaml diff --git a/docs/source/api.rst b/docs/source/api.rst index 73383fd90..3d6fc88c8 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -170,6 +170,7 @@ Connectivity and Cycles rustworkx.all_simple_paths rustworkx.all_pairs_all_simple_paths rustworkx.stoer_wagner_min_cut + rustworkx.longest_simple_path .. _graph-ops: @@ -340,6 +341,7 @@ the functions from the explicitly typed based on the data type. rustworkx.digraph_bfs_search rustworkx.digraph_dijkstra_search rustworkx.digraph_node_link_json + rustworkx.digraph_longest_simple_path .. _api-functions-pygraph: @@ -397,6 +399,7 @@ typed API based on the data type. rustworkx.graph_bfs_search rustworkx.graph_dijkstra_search rustworkx.graph_node_link_json + rustworkx.graph_longest_simple_path Exceptions ========== diff --git a/releasenotes/notes/add-longest-simple-path-afdc4538c49bc38f.yaml b/releasenotes/notes/add-longest-simple-path-afdc4538c49bc38f.yaml new file mode 100644 index 000000000..fc1eb9420 --- /dev/null +++ b/releasenotes/notes/add-longest-simple-path-afdc4538c49bc38f.yaml @@ -0,0 +1,33 @@ +--- +features: + - | + Added a new function, :func:`~.longest_simple_path` which is used to search + all the simple paths between all pairs of nodes in a graph and return + the longest path found. For example: + + .. jupyter-execute:: + + import rustworkx as rx + + graph = rx.generators.binomial_tree_graph(5) + longest_path = rx.longest_simple_path(graph) + print(longest_path) + + Then visualizing the nodes in the longest path found: + + .. jupyter-execute:: + + from rustworkx.visualization import mpl_draw + + path_set = set(longest_path) + colors = [] + for index in range(len(graph)): + if index in path_set: + colors.append('r') + else: + colors.append('#1f78b4') + mpl_draw(graph, node_color=colors) + - | + Added a new function ``longest_simple_path_multiple_targets()`` to + rustworkx-core. This function will return the longest simple path from a + source node to a ``HashSet`` of target nodes. diff --git a/rustworkx-core/src/connectivity/all_simple_paths.rs b/rustworkx-core/src/connectivity/all_simple_paths.rs index 3d75f5eb3..4a7293ce2 100644 --- a/rustworkx-core/src/connectivity/all_simple_paths.rs +++ b/rustworkx-core/src/connectivity/all_simple_paths.rs @@ -136,9 +136,96 @@ where output } +/// Returns the longest of all the simple paths from `from` node to all nodes in `to`, which contains at least `min_intermediate_nodes` nodes +/// and at most `max_intermediate_nodes`, if given, or limited by the graph's order otherwise. The simple path is a path without repetitions. 
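+/// Returns `None` if there is no simple path from `from` to any of the nodes in `to`.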
+/// +/// # Example +/// ``` +/// use petgraph::prelude::*; +/// use hashbrown::HashSet; +/// use rustworkx_core::connectivity::longest_simple_path_multiple_targets; +/// +/// let mut graph = DiGraph::<&str, i32>::new(); +/// +/// let a = graph.add_node("a"); +/// let b = graph.add_node("b"); +/// let c = graph.add_node("c"); +/// let d = graph.add_node("d"); +/// +/// graph.extend_with_edges(&[(a, b, 1), (b, c, 1), (c, d, 1), (a, b, 1), (b, d, 1)]); +/// +/// let mut to_set = HashSet::new(); +/// to_set.insert(d); +/// +/// let path = longest_simple_path_multiple_targets(&graph, a, &to_set); +/// +/// let expected = vec![a, b, c, d]; +/// assert_eq!(path.unwrap(), expected); +/// ``` +pub fn longest_simple_path_multiple_targets( + graph: G, + from: G::NodeId, + to: &HashSet, +) -> Option> +where + G: NodeCount, + G: IntoNeighborsDirected, + G::NodeId: Eq + Hash, +{ + // list of visited nodes + let mut visited: IndexSet = IndexSet::from_iter(Some(from)); + // list of childs of currently exploring path nodes, + // last elem is list of childs of last visited node + let mut stack = vec![graph.neighbors_directed(from, Outgoing)]; + + let mut output_path: Option> = None; + + let update_path = |new_path: Vec, + output_path: &Option>| + -> Option> { + match output_path.as_ref() { + None => Some(new_path), + Some(path) => { + if path.len() < new_path.len() { + Some(new_path) + } else { + None + } + } + } + }; + + while let Some(children) = stack.last_mut() { + if let Some(child) = children.next() { + if !visited.contains(&child) { + if to.contains(&child) { + let new_path: Vec = + visited.iter().chain(&[child]).copied().collect(); + let temp = update_path(new_path, &output_path); + if temp.is_some() { + output_path = temp; + } + } + visited.insert(child); + if to.iter().any(|n| !visited.contains(n)) { + stack.push(graph.neighbors_directed(child, Outgoing)); + } else { + visited.pop(); + } + } + } else { + stack.pop(); + visited.pop(); + } + } + output_path +} + #[cfg(test)] mod tests { - use crate::connectivity::all_simple_paths_multiple_targets; + use crate::connectivity::{ + all_simple_paths_multiple_targets, longest_simple_path_multiple_targets, + }; use hashbrown::HashSet; use petgraph::prelude::*; @@ -446,4 +533,141 @@ mod tests { &vec![vec![b, d], vec![b, f, e, d], vec![b, c, d]] ); } + + #[test] + fn test_longest_simple_path() { + // create a path graph + let mut graph = Graph::new_undirected(); + let a = graph.add_node(0); + let b = graph.add_node(1); + let c = graph.add_node(2); + let d = graph.add_node(3); + let e = graph.add_node(4); + + graph.extend_with_edges(&[(a, b, 1), (b, c, 1), (c, d, 1), (d, e, 1)]); + + let mut to_set = HashSet::new(); + to_set.insert(d); + + let path = longest_simple_path_multiple_targets(&graph, a, &to_set); + + assert_eq!(path.unwrap(), vec![a, b, c, d]); + } + + #[test] + fn test_longest_simple_path_with_two_targets_emits_two_paths() { + // create a path graph + let mut graph = Graph::new_undirected(); + let a = graph.add_node(0); + let b = graph.add_node(1); + let c = graph.add_node(2); + let d = graph.add_node(3); + let e = graph.add_node(4); + + graph.extend_with_edges(&[(a, b, 1), (b, c, 1), (c, d, 1), (d, e, 1), (c, e, 1)]); + + let mut to_set = HashSet::new(); + to_set.insert(d); + to_set.insert(e); + + let path = longest_simple_path_multiple_targets(&graph, a, &to_set); + + assert_eq!(path.unwrap(), vec![a, b, c, e, d]); + } + + #[test] + fn test_digraph_longest_simple_path_with_two_targets_emits_two_paths() { + // create a path graph + let mut graph 
= Graph::new(); + let a = graph.add_node(0); + let b = graph.add_node(1); + let c = graph.add_node(2); + let d = graph.add_node(3); + let e = graph.add_node(4); + + graph.extend_with_edges(&[(a, b, 1), (b, c, 1), (c, d, 1), (d, e, 1), (c, e, 1)]); + + let mut to_set = HashSet::new(); + to_set.insert(d); + to_set.insert(e); + + let path = longest_simple_path_multiple_targets(&graph, a, &to_set); + + assert_eq!(path.unwrap(), vec![a, b, c, d, e]); + } + + #[test] + fn test_longest_simple_paths_with_two_targets_in_line_emits_two_paths() { + // create a path graph + let mut graph = Graph::new_undirected(); + let a = graph.add_node(0); + let b = graph.add_node(1); + let c = graph.add_node(2); + let d = graph.add_node(3); + let e = graph.add_node(4); + + graph.extend_with_edges(&[(a, b, 1), (b, c, 1), (c, d, 1), (d, e, 1)]); + + let mut to_set = HashSet::new(); + to_set.insert(c); + to_set.insert(d); + + let path = longest_simple_path_multiple_targets(&graph, a, &to_set); + + assert_eq!(path.unwrap(), vec![a, b, c, d]); + } + + #[test] + fn test_longest_simple_paths_source_target() { + // create a path graph + let mut graph = Graph::new_undirected(); + let a = graph.add_node(0); + let b = graph.add_node(1); + let c = graph.add_node(2); + let d = graph.add_node(3); + let e = graph.add_node(4); + + graph.extend_with_edges(&[(a, b, 1), (b, c, 1), (c, d, 1), (d, e, 1)]); + + let mut to_set = HashSet::new(); + to_set.insert(a); + + let path = longest_simple_path_multiple_targets(&graph, a, &to_set); + + assert_eq!(path, None); + } + + #[test] + fn test_longest_simple_paths_on_non_trivial_graph() { + // create a path graph + let mut graph = Graph::new(); + let a = graph.add_node(0); + let b = graph.add_node(1); + let c = graph.add_node(2); + let d = graph.add_node(3); + let e = graph.add_node(4); + let f = graph.add_node(5); + + graph.extend_with_edges(&[ + (a, b, 1), + (b, c, 1), + (c, d, 1), + (d, e, 1), + (e, f, 1), + (a, f, 1), + (b, f, 1), + (b, d, 1), + (f, e, 1), + (e, c, 1), + (e, d, 1), + ]); + + let mut to_set = HashSet::new(); + to_set.insert(c); + to_set.insert(d); + + let path = longest_simple_path_multiple_targets(&graph, b, &to_set); + + assert_eq!(path.unwrap(), vec![b, f, e, c, d],); + } } diff --git a/rustworkx-core/src/connectivity/mod.rs b/rustworkx-core/src/connectivity/mod.rs index 828dfb4b6..b66405c55 100644 --- a/rustworkx-core/src/connectivity/mod.rs +++ b/rustworkx-core/src/connectivity/mod.rs @@ -21,7 +21,9 @@ mod cycle_basis; mod find_cycle; mod min_cut; -pub use all_simple_paths::all_simple_paths_multiple_targets; +pub use all_simple_paths::{ + all_simple_paths_multiple_targets, longest_simple_path_multiple_targets, +}; pub use biconnected::articulation_points; pub use chain::chain_decomposition; pub use conn_components::bfs_undirected; diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 40e3cd6c7..c1c31750f 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -2504,3 +2504,46 @@ def _graph_node_link_json(graph, path=None, graph_attrs=None, node_attrs=None, e return graph_node_link_json( graph, path=path, graph_attrs=graph_attrs, node_attrs=node_attrs, edge_attrs=edge_attrs ) + + +@functools.singledispatch +def longest_simple_path(graph): + """Return a longest simple path in the graph + + This function searches computes all pairs of all simple paths and returns + a path of the longest length from that set. 
It is roughly equivalent to + running something like:: + + from rustworkx import all_pairs_all_simple_paths + + max((y.values for y in all_pairs_all_simple_paths(graph).values()), key=lambda x: len(x)) + + but this function will be more efficient than using ``max()`` as the search + is evaluated in parallel before returning to Python. In the case of multiple + paths of the same maximum length being present in the graph only one will be + provided. There are no guarantees on which of the multiple longest paths + will be returned (as it is determined by the parallel execution order). This + is a tradeoff to improve runtime performance. If a stable return is required + in such case consider using the ``max()`` equivalent above instead. + + This function is multithreaded and will launch a thread pool with threads + equal to the number of CPUs by default. You can tune the number of threads + with the ``RAYON_NUM_THREADS`` environment variable. For example, setting + ``RAYON_NUM_THREADS=4`` would limit the thread pool to 4 threads. + + :param PyGraph graph: The graph to find the longest path in + + :returns: A sequence of node indices that represent the longest simple graph + found in the graph. If the graph is empty ``None`` will be returned instead. + :rtype: NodeIndices + """ + + +@longest_simple_path.register(PyDiGraph) +def _digraph_longest_simple_path(graph): + return digraph_longest_simple_path(graph) + + +@longest_simple_path.register(PyGraph) +def _graph_longest_simple_path(graph): + return graph_longest_simple_path(graph) diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index e17d30733..70cfbb6d2 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -16,26 +16,29 @@ mod all_pairs_all_simple_paths; mod johnson_simple_cycles; use super::{ - digraph, get_edge_iter_with_weights, graph, iterators::NodeIndices, score, weight_callable, - InvalidNode, NullGraph, + digraph, get_edge_iter_with_weights, graph, score, weight_callable, InvalidNode, NullGraph, }; use hashbrown::{HashMap, HashSet}; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; -use pyo3::types::PyDict; -use pyo3::Python; - use petgraph::algo; use petgraph::stable_graph::NodeIndex; use petgraph::unionfind::UnionFind; use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable}; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyDict; +use pyo3::Python; +use rayon::prelude::*; use ndarray::prelude::*; use numpy::IntoPyArray; -use crate::iterators::{AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList}; +use crate::iterators::{ + AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList, NodeIndices, +}; +use crate::{EdgeType, StablePyGraph}; + use rustworkx_core::connectivity; /// Return a list of cycles which form a basis for cycles of a given PyGraph @@ -628,6 +631,104 @@ pub fn graph_all_pairs_all_simple_paths( )) } +fn longest_simple_path( + graph: &StablePyGraph, +) -> Option { + if graph.node_count() == 0 { + return None; + } else if graph.edge_count() == 0 { + return Some(NodeIndices { + nodes: vec![graph.node_indices().next()?.index()], + }); + } + let node_indices: Vec = graph.node_indices().collect(); + let node_index_set = node_indices.iter().copied().collect(); + Some(NodeIndices { + nodes: node_indices + .par_iter() + .filter_map(|u| { + connectivity::longest_simple_path_multiple_targets(graph, *u, &node_index_set) + }) + .max_by_key(|x| x.len()) + .unwrap() + .into_iter() + .map(|x| x.index()) + 
.collect(), + }) +} + +/// Return a longest simple path in the graph +/// +/// This function searches computes all pairs of all simple paths and returns +/// a path of the longest length from that set. It is roughly equivalent to +/// running something like:: +/// +/// from rustworkx import all_pairs_all_simple_paths +/// +/// max((y.values for y in all_pairs_all_simple_paths(graph).values()), key=lambda x: len(x)) +/// +/// but this function will be more efficient than using ``max()`` as the search +/// is evaluated in parallel before returning to Python. In the case of multiple +/// paths of the same maximum length being present in the graph only one will be +/// provided. There are no guarantees on which of the multiple longest paths +/// will be returned (as it is determined by the parallel execution order). This +/// is a tradeoff to improve runtime performance. If a stable return is required +/// in such case consider using the ``max()`` equivalent above instead. +/// +/// This function is multithreaded and will launch a thread pool with threads +/// equal to the number of CPUs by default. You can tune the number of threads +/// with the ``RAYON_NUM_THREADS`` environment variable. For example, setting +/// ``RAYON_NUM_THREADS=4`` would limit the thread pool to 4 threads. +/// +/// :param PyDiGraph graph: The graph to find the longest path in +/// +/// :returns: A sequence of node indices that represent the longest simple graph +/// found in the graph. If the graph is empty ``None`` will be returned instead. +/// :rtype: NodeIndices +#[pyfunction] +#[pyo3(text_signature = "(graph, /)")] +pub fn digraph_longest_simple_path(graph: &digraph::PyDiGraph) -> Option { + longest_simple_path(&graph.graph) +} + +/// Return a longest simple path in the graph +/// +/// This function searches computes all pairs of all simple paths and returns +/// a path of the longest length from that set. It is roughly equivalent to +/// running something like:: +/// +/// from rustworkx import all_pairs_all_simple_paths +/// +/// simple_path_pairs = rx.all_pairs_all_simple_paths(graph) +/// longest_path = max( +/// (u for y in simple_path_pairs.values() for z in y.values() for u in z), +/// key=lambda x: len(x), +/// ) +/// +/// but this function will be more efficient than using ``max()`` as the search +/// is evaluated in parallel before returning to Python. In the case of multiple +/// paths of the same maximum length being present in the graph only one will be +/// provided. There are no guarantees on which of the multiple longest paths +/// will be returned (as it is determined by the parallel execution order). This +/// is a tradeoff to improve runtime performance. If a stable return is required +/// in such case consider using the ``max()`` equivalent above instead. +/// +/// This function is multithreaded and will launch a thread pool with threads +/// equal to the number of CPUs by default. You can tune the number of threads +/// with the ``RAYON_NUM_THREADS`` environment variable. For example, setting +/// ``RAYON_NUM_THREADS=4`` would limit the thread pool to 4 threads. +/// +/// :param PyGraph graph: The graph to find the longest path in +/// +/// :returns: A sequence of node indices that represent the longest simple graph +/// found in the graph. If the graph is empty ``None`` will be returned instead. 
+/// :rtype: NodeIndices +#[pyfunction] +#[pyo3(text_signature = "(graph, /)")] +pub fn graph_longest_simple_path(graph: &graph::PyGraph) -> Option { + longest_simple_path(&graph.graph) +} + /// Return the core number for each node in the graph. /// /// A k-core is a maximal subgraph that contains nodes of degree k or more. diff --git a/src/lib.rs b/src/lib.rs index dedec6c6e..2740cc1df 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -392,6 +392,8 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(graph_adjacency_matrix))?; m.add_wrapped(wrap_pyfunction!(graph_all_pairs_all_simple_paths))?; m.add_wrapped(wrap_pyfunction!(digraph_all_pairs_all_simple_paths))?; + m.add_wrapped(wrap_pyfunction!(graph_longest_simple_path))?; + m.add_wrapped(wrap_pyfunction!(digraph_longest_simple_path))?; m.add_wrapped(wrap_pyfunction!(graph_all_simple_paths))?; m.add_wrapped(wrap_pyfunction!(digraph_all_simple_paths))?; m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_paths))?; diff --git a/tests/rustworkx_tests/digraph/test_all_simple_paths.py b/tests/rustworkx_tests/digraph/test_all_simple_paths.py index 81f9f8c95..c63ea2c41 100644 --- a/tests/rustworkx_tests/digraph/test_all_simple_paths.py +++ b/tests/rustworkx_tests/digraph/test_all_simple_paths.py @@ -305,3 +305,42 @@ def test_all_simple_path_no_path(self): def test_all_simple_paths_empty(self): self.assertEqual({}, rustworkx.all_pairs_all_simple_paths(rustworkx.PyDiGraph())) + + +class TestDiGraphLongestSimplePaths(unittest.TestCase): + def setUp(self): + super().setUp() + self.edges = [ + (0, 1), + (0, 2), + (0, 3), + (1, 2), + (1, 3), + (2, 3), + (2, 4), + (3, 2), + (3, 4), + (4, 2), + (4, 5), + (5, 2), + (5, 3), + ] + + def test_all_simple_paths(self): + dag = rustworkx.PyDAG() + for i in range(6): + dag.add_node(i) + dag.add_edges_from_no_data(self.edges) + res = rustworkx.longest_simple_path(dag) + expected = {(0, 1, 2, 3, 4, 5), (0, 1, 3, 2, 4, 5), (0, 1, 2, 4, 5, 3), (0, 1, 3, 4, 5, 2)} + self.assertIn(tuple(res), expected) + + def test_all_simple_path_no_path(self): + dag = rustworkx.PyDAG() + dag.add_node(0) + dag.add_node(1) + res = rustworkx.longest_simple_path(dag) + self.assertEqual([0], res) + + def test_all_simple_paths_empty(self): + self.assertIsNone(rustworkx.longest_simple_path(rustworkx.PyDiGraph())) diff --git a/tests/rustworkx_tests/graph/test_all_simple_paths.py b/tests/rustworkx_tests/graph/test_all_simple_paths.py index fe13c13f6..b179be29a 100644 --- a/tests/rustworkx_tests/graph/test_all_simple_paths.py +++ b/tests/rustworkx_tests/graph/test_all_simple_paths.py @@ -243,3 +243,32 @@ def test_all_simple_path_no_path(self): def test_all_simple_paths_empty(self): self.assertEqual({}, rustworkx.all_pairs_all_simple_paths(rustworkx.PyGraph())) + + +class TestGraphLongestSimplePath(unittest.TestCase): + def setUp(self): + super().setUp() + self.graph = rustworkx.generators.cycle_graph(4) + + def test_all_simple_paths(self): + res = rustworkx.longest_simple_path(self.graph) + expected = { + (0, 3, 2, 1), + (0, 1, 2, 3), + (1, 0, 3, 2), + (1, 2, 3, 0), + (2, 1, 0, 3), + (2, 3, 0, 1), + (3, 0, 1, 2), + (3, 2, 1, 0), + } + self.assertIn(tuple(res), expected) + + def test_all_simple_path_no_path(self): + graph = rustworkx.PyGraph() + graph.add_node(0) + graph.add_node(1) + self.assertEqual([0], rustworkx.longest_simple_path(graph)) + + def test_all_simple_paths_empty(self): + self.assertIsNone(rustworkx.longest_simple_path(rustworkx.PyGraph())) From 1518b1e97b3de7b936678e993724bb93a5dda278 
Mon Sep 17 00:00:00 2001 From: Alexander Ivrii Date: Wed, 31 May 2023 00:04:04 +0300 Subject: [PATCH 20/37] moving greedy-color to rustworkx-core (#875) * moving greedy-color to rustworkx-core * doc fix * reno and doc fix * more docs fixes * merge fixes and docs updates * moving coloring tests inside the module * release notes according to review suggestions * renaming to greedy_node_color * Fix doc warning --------- Co-authored-by: Matthew Treinish --- ...migrate-greedy-color-c3239f35840eec18.yaml | 5 + rustworkx-core/src/coloring.rs | 149 ++++++++++++++++++ rustworkx-core/src/lib.rs | 2 + src/coloring.rs | 42 +---- 4 files changed, 160 insertions(+), 38 deletions(-) create mode 100644 releasenotes/notes/migrate-greedy-color-c3239f35840eec18.yaml create mode 100644 rustworkx-core/src/coloring.rs diff --git a/releasenotes/notes/migrate-greedy-color-c3239f35840eec18.yaml b/releasenotes/notes/migrate-greedy-color-c3239f35840eec18.yaml new file mode 100644 index 000000000..a63ea6ca3 --- /dev/null +++ b/releasenotes/notes/migrate-greedy-color-c3239f35840eec18.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Added a new function, ``greedy_node_color``, to ``rustworkx-core`` in a new + ``coloring`` module. It colors a graph using a greedy graph coloring algorithm. diff --git a/rustworkx-core/src/coloring.rs b/rustworkx-core/src/coloring.rs new file mode 100644 index 000000000..0bea3451b --- /dev/null +++ b/rustworkx-core/src/coloring.rs @@ -0,0 +1,149 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use std::cmp::Reverse; +use std::hash::Hash; + +use crate::dictmap::*; +use hashbrown::{HashMap, HashSet}; +use petgraph::visit::{EdgeRef, IntoEdges, IntoNodeIdentifiers, NodeCount}; +use rayon::prelude::*; + +/// Color a graph using a greedy graph coloring algorithm. +/// +/// This function uses a `largest-first` strategy as described in: +/// +/// Adrian Kosowski, and Krzysztof Manuszewski, Classical Coloring of Graphs, +/// Graph Colorings, 2-19, 2004. ISBN 0-8218-3458-4. +/// +/// to color the nodes with higher degree first. +/// +/// The coloring problem is NP-hard and this is a heuristic algorithm +/// which may not return an optimal solution. 
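+///
+/// Informally, the strategy sorts the nodes by decreasing degree and then
+/// assigns each node the smallest color index that is not already used by
+/// one of its previously colored neighbors.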
+/// +/// Arguments: +/// +/// * `graph` - The graph object to run the algorithm on +/// +/// # Example +/// ```rust +/// +/// use petgraph::graph::Graph; +/// use petgraph::graph::NodeIndex; +/// use petgraph::Undirected; +/// use rustworkx_core::dictmap::*; +/// use rustworkx_core::coloring::greedy_node_color; +/// +/// let g = Graph::<(), (), Undirected>::from_edges(&[(0, 1), (0, 2)]); +/// let colors = greedy_node_color(&g); +/// let mut expected_colors = DictMap::new(); +/// expected_colors.insert(NodeIndex::new(0), 0); +/// expected_colors.insert(NodeIndex::new(1), 1); +/// expected_colors.insert(NodeIndex::new(2), 1); +/// assert_eq!(colors, expected_colors); +/// ``` +/// +/// +pub fn greedy_node_color(graph: G) -> DictMap +where + G: NodeCount + IntoNodeIdentifiers + IntoEdges, + G::NodeId: Hash + Eq + Send + Sync, +{ + let mut colors: DictMap = DictMap::new(); + let mut node_vec: Vec = graph.node_identifiers().collect(); + + let mut sort_map: HashMap = HashMap::with_capacity(graph.node_count()); + for k in node_vec.iter() { + sort_map.insert(*k, graph.edges(*k).count()); + } + node_vec.par_sort_by_key(|k| Reverse(sort_map.get(k))); + + for node in node_vec { + let mut neighbor_colors: HashSet = HashSet::new(); + for edge in graph.edges(node) { + let target = edge.target(); + let existing_color = match colors.get(&target) { + Some(color) => color, + None => continue, + }; + neighbor_colors.insert(*existing_color); + } + let mut current_color: usize = 0; + loop { + if !neighbor_colors.contains(¤t_color) { + break; + } + current_color += 1; + } + colors.insert(node, current_color); + } + + colors +} + +#[cfg(test)] + +mod test_node_coloring { + + use crate::coloring::greedy_node_color; + use crate::dictmap::DictMap; + use crate::petgraph::Graph; + + use petgraph::graph::NodeIndex; + use petgraph::Undirected; + + #[test] + fn test_greedy_node_color_empty_graph() { + // Empty graph + let graph = Graph::<(), (), Undirected>::new_undirected(); + let colors = greedy_node_color(&graph); + let expected_colors: DictMap = [].into_iter().collect(); + assert_eq!(colors, expected_colors); + } + + #[test] + fn test_greedy_node_color_simple_graph() { + // Simple graph + let graph = Graph::<(), (), Undirected>::from_edges(&[(0, 1), (0, 2)]); + let colors = greedy_node_color(&graph); + let expected_colors: DictMap = [ + (NodeIndex::new(0), 0), + (NodeIndex::new(1), 1), + (NodeIndex::new(2), 1), + ] + .into_iter() + .collect(); + assert_eq!(colors, expected_colors); + } + + #[test] + fn test_greedy_node_color_simple_graph_large_degree() { + // Graph with multiple edges + let graph = Graph::<(), (), Undirected>::from_edges(&[ + (0, 1), + (0, 2), + (0, 2), + (0, 2), + (0, 2), + (0, 2), + ]); + let colors = greedy_node_color(&graph); + let expected_colors: DictMap = [ + (NodeIndex::new(0), 0), + (NodeIndex::new(1), 1), + (NodeIndex::new(2), 1), + ] + .into_iter() + .collect(); + assert_eq!(colors, expected_colors); + } +} diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index d1c7e72ef..754535bf5 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -72,6 +72,8 @@ pub type Result = core::result::Result; /// Module for centrality algorithms. pub mod centrality; +/// Module for coloring algorithms. +pub mod coloring; pub mod connectivity; pub mod generators; /// Module for maximum weight matching algorithms. 
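
As a quick sanity check of the Python-facing wrapper that the following
``src/coloring.rs`` change delegates to ``greedy_node_color`` (a minimal
sketch, assuming the existing ``star_graph`` generator and the top-level
``graph_greedy_color`` binding), the hub of a star graph is colored first and
every leaf then receives the next color::

    import rustworkx as rx

    # Node 0 is connected to nodes 1..4, so it has the largest degree and is
    # colored first by the largest-first strategy.
    graph = rx.generators.star_graph(5)
    colors = rx.graph_greedy_color(graph)

    assert colors[0] == 0                            # hub gets color 0
    assert all(colors[v] == 1 for v in range(1, 5))  # leaves share color 1
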
diff --git a/src/coloring.rs b/src/coloring.rs index b432ee1d3..557064658 100644 --- a/src/coloring.rs +++ b/src/coloring.rs @@ -11,21 +11,12 @@ // under the License. use crate::graph; -use rustworkx_core::dictmap::*; - -use hashbrown::{HashMap, HashSet}; -use std::cmp::Reverse; +use rustworkx_core::coloring::greedy_node_color; use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3::Python; -use petgraph::graph::NodeIndex; -use petgraph::prelude::*; -use petgraph::visit::NodeCount; - -use rayon::prelude::*; - /// Color a :class:`~.PyGraph` object using a greedy graph coloring algorithm. /// /// This function uses a `largest-first` strategy as described in [1]_ and colors @@ -61,35 +52,10 @@ use rayon::prelude::*; #[pyfunction] #[pyo3(text_signature = "(graph, /)")] pub fn graph_greedy_color(py: Python, graph: &graph::PyGraph) -> PyResult { - let mut colors: DictMap = DictMap::new(); - let mut node_vec: Vec = graph.graph.node_indices().collect(); - let mut sort_map: HashMap = HashMap::with_capacity(graph.node_count()); - for k in node_vec.iter() { - sort_map.insert(*k, graph.graph.edges(*k).count()); - } - node_vec.par_sort_by_key(|k| Reverse(sort_map.get(k))); - for u_index in node_vec { - let mut neighbor_colors: HashSet = HashSet::new(); - for edge in graph.graph.edges(u_index) { - let target = edge.target().index(); - let existing_color = match colors.get(&target) { - Some(node) => node, - None => continue, - }; - neighbor_colors.insert(*existing_color); - } - let mut count: usize = 0; - loop { - if !neighbor_colors.contains(&count) { - break; - } - count += 1; - } - colors.insert(u_index.index(), count); - } + let colors = greedy_node_color(&graph.graph); let out_dict = PyDict::new(py); - for (index, color) in colors { - out_dict.set_item(index, color)?; + for (node, color) in colors { + out_dict.set_item(node.index(), color)?; } Ok(out_dict.into()) } From 6cf88551732de843e886a44621730164b4aa1f69 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Wed, 31 May 2023 17:53:27 -0400 Subject: [PATCH 21/37] Add HITS algorithm (#790) * Add the sketch of PageRank * More progress towards pagerank * Use CentralityMapping and FailedToConverge * Finalize PageRank * First test does not run * Remove unwanted triplet * Fix clippy warning * Handle personalization correctly * Add more tests * Cargo fmt * Add scipy to test requirements * Skip SciPy tests in case architecture does not have it * Ignore flake8 errors that do not help * Flake8 * Handle dangling weights * Add more tests * Add nstart argument * Cargo Clippy * Documentation * Fix typo in URL * Add HITS algorithm * Add tests * Documentation * Add documentation entry for HITS * Update releasenotes/notes/add-pagerank-bef0de7d46026071.yaml Co-authored-by: Matthew Treinish * Tweak pyfunction signature * Tweak pyfunction signature * Add scipy to aarch64 test requirements * Address comments from code review * Clippy is always right * Apply suggestions from code review Co-authored-by: Matthew Treinish --------- Co-authored-by: Matthew Treinish --- docs/source/api.rst | 1 + .../notes/add-hits-dec9da09240e8787.yaml | 28 +++ src/lib.rs | 1 + src/link_analysis.rs | 160 +++++++++++++++++- tests/rustworkx_tests/digraph/test_hits.py | 99 +++++++++++ 5 files changed, 288 insertions(+), 1 deletion(-) create mode 100644 releasenotes/notes/add-hits-dec9da09240e8787.yaml create mode 100644 tests/rustworkx_tests/digraph/test_hits.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 3d6fc88c8..b118a69e7 
100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -67,6 +67,7 @@ Link Analysis :toctree: apiref rustworkx.pagerank + rustworkx.hits .. _traversal: diff --git a/releasenotes/notes/add-hits-dec9da09240e8787.yaml b/releasenotes/notes/add-hits-dec9da09240e8787.yaml new file mode 100644 index 000000000..53b57f5d9 --- /dev/null +++ b/releasenotes/notes/add-hits-dec9da09240e8787.yaml @@ -0,0 +1,28 @@ +--- +features: + - | + Added a new function, :func:`~.hits()` which is used to + compute the hubs and authorities for all nodes in a given directed graph. + For example: + + .. jupyter-execute:: + + import rustworkx as rx + from rustworkx.visualization import mpl_draw + + graph = rx.generators.directed_hexagonal_lattice_graph(2, 2) + hubs, _ = rx.hits(graph) + + # Generate a color list + colors = [] + for node in graph.node_indices(): + hub_score = hubs[node] + graph[node] = hub_score + colors.append(hub_score) + mpl_draw( + graph, + with_labels=True, + node_color=colors, + node_size=650, + labels=lambda x: "{0:.2f}".format(x) + ) diff --git a/src/lib.rs b/src/lib.rs index 2740cc1df..af1525515 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -491,6 +491,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(pagerank))?; + m.add_wrapped(wrap_pyfunction!(hits))?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/link_analysis.rs b/src/link_analysis.rs index d38cf0ce3..981f74d34 100644 --- a/src/link_analysis.rs +++ b/src/link_analysis.rs @@ -22,7 +22,7 @@ use crate::{weight_callable, FailedToConverge}; use hashbrown::HashMap; use ndarray::prelude::*; -use ndarray_stats::DeviationExt; +use ndarray_stats::{DeviationExt, QuantileExt}; use petgraph::prelude::*; use petgraph::visit::IntoEdgeReferences; use petgraph::visit::NodeIndexable; @@ -225,3 +225,161 @@ pub fn pagerank( centralities: out_map, }) } + +/// Computes the hubs and authorities in a :class:`~PyDiGraph`. +/// +/// For details on the HITS algorithm, refer to: +/// +/// J. Kleinberg. “Authoritative Sources in a Hyperlinked Environment”. +/// Journal of the ACM, 46 (5), (1999). +/// +/// +/// This function uses a power iteration method to compute the hubs and authorities +/// and convergence is not guaranteed. The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the hubs and authorities. +/// +/// :param PyDiGraph graph: The graph object to run the algorithm on +/// :param weight_fn: An optional input callable that will be passed the edge's +/// payload object and is expected to return a `float` weight for that edge. +/// If this is not specified 1.0 will be used as the weight +/// for every edge in ``graph`` +/// :param dict nstart: Optional starting value for the power iteration for each node. +/// :param float tol: The error tolerance used when checking for convergence in the +/// power method. If this is not specified default value of 1e-8 is used. +/// :param int max_iter: The maximum number of iterations in the power method. If +/// not specified a default value of 100 is used. +/// :param boolean normalized: If the scores should be normalized (defaults to True). 
+/// +/// :returns: a tuple of read-only dict-like object whose keys are the node indices. The first value in the tuple +/// contain the hubs scores. The second value contains the authority scores. +/// :rtype: tuple[CentralityMapping, CentralityMapping] +#[pyfunction( + signature = ( + graph, + weight_fn=None, + nstart=None, + tol=1e-6, + max_iter=100, + normalized=true, + ) +)] +#[pyo3( + text_signature = "(graph, /, weight_fn=None, nstart=None, tol=1.0e-8, max_iter=100, normalized=True)" +)] +pub fn hits( + py: Python, + graph: &PyDiGraph, + weight_fn: Option, + nstart: Option>, + tol: f64, + max_iter: usize, + normalized: bool, +) -> PyResult<(CentralityMapping, CentralityMapping)> { + // we use the node bound to make the code work if nodes were removed + let n = graph.graph.node_count(); + let mat_size = graph.graph.node_bound(); + let node_indices: Vec = graph.graph.node_indices().map(|x| x.index()).collect(); + + // Handle empty case + if n == 0 { + return Ok(( + CentralityMapping { + centralities: DictMap::new(), + }, + CentralityMapping { + centralities: DictMap::new(), + }, + )); + } + + // Grab the graph weights from Python to Rust + let mut adjacent: HashMap<(usize, usize), f64> = + HashMap::with_capacity(graph.graph.edge_count()); + let default_weight: f64 = 1.0; + + for edge in graph.graph.edge_references() { + let i = NodeIndexable::to_index(&graph.graph, edge.source()); + let j = NodeIndexable::to_index(&graph.graph, edge.target()); + let weight = edge.weight().clone_ref(py); + + let edge_weight = weight_callable(py, &weight_fn, &weight, default_weight)?; + + *adjacent.entry((i, j)).or_insert(0.0) += edge_weight; + } + + // Create sparse adjacency matrix and transpose + let mut a = TriMat::new((mat_size, mat_size)); + let mut a_t = TriMat::new((mat_size, mat_size)); + for ((i, j), weight) in adjacent.into_iter() { + a.add_triplet(i, j, weight); + a_t.add_triplet(j, i, weight); + } + let a: CsMat<_> = a.to_csr(); + let a_t: CsMat<_> = a_t.to_csr(); + + // Initial guess of eigenvector of A^T @ A + let mut authority = Array1::::zeros(mat_size); + let default_auth = (n as f64).recip(); + + // Handle custom start + if let Some(nstart) = nstart { + for i in &node_indices { + authority[*i] = *nstart.get(i).unwrap_or(&0.0); + } + let a_sum = authority.sum(); + authority /= a_sum; + } else { + for i in &node_indices { + authority[*i] = default_auth; + } + } + + // Power Method iteration for A^T @ A + let mut has_converged = false; + for _ in 0..max_iter { + // Instead of evaluating A^T @ A, which might not be sparse + // we prefer to calculate A^T (A @ x); A @ x is a vector hence + // we don't have to worry about sparsity + let temp_hub = &a * &authority; + let mut new_authority = &a_t * &temp_hub; + new_authority /= *new_authority.max_skipnan(); + let norm: f64 = new_authority.l1_dist(&authority).unwrap(); + if norm < tol { + has_converged = true; + break; + } else { + authority = new_authority; + } + } + + // Convert to custom return type + if !has_converged { + return Err(FailedToConverge::new_err(format!( + "Function failed to converge on a solution in {} iterations", + max_iter + ))); + } + + let mut hubs = &a * &authority; + + if normalized { + hubs /= hubs.sum(); + authority /= authority.sum(); + } + + let hubs_map: DictMap = node_indices.iter().map(|x| (*x, hubs[*x])).collect(); + let auth_map: DictMap = node_indices.iter().map(|x| (*x, authority[*x])).collect(); + + Ok(( + CentralityMapping { + centralities: hubs_map, + }, + CentralityMapping { + centralities: auth_map, + 
}, + )) +} diff --git a/tests/rustworkx_tests/digraph/test_hits.py b/tests/rustworkx_tests/digraph/test_hits.py new file mode 100644 index 000000000..7ee0b1b69 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_hits.py @@ -0,0 +1,99 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# These tests are adapated from the networkx test cases: +# https://github.com/networkx/networkx/blob/cea310f9066efc0d5ff76f63d33dbc3eefe61f6b/networkx/algorithms/link_analysis/tests/test_pagerank.py + +import unittest + +import rustworkx +import networkx as nx + + +class TestHits(unittest.TestCase): + def setUp(self): + try: + # required for networkx.hits to work + import scipy + + self.assertIsNotNone(scipy.__version__) + except ModuleNotFoundError: + self.skipTest("SciPy is not installed, skipping HITS tests") + + def test_hits(self): + edges = [(0, 2), (0, 4), (1, 0), (2, 4), (4, 3), (4, 2), (5, 4)] + + rx_graph = rustworkx.PyDiGraph() + rx_graph.extend_from_edge_list(edges) + + nx_graph = nx.DiGraph() + nx_graph.add_edges_from(edges) + + rx_h, rx_a = rustworkx.hits(rx_graph) + nx_h, nx_a = nx.hits(nx_graph) + + for v in rx_graph.node_indices(): + self.assertAlmostEqual(rx_h[v], nx_h[v], delta=1.0e-4) + self.assertAlmostEqual(rx_a[v], nx_a[v], delta=1.0e-4) + + def test_no_convergence(self): + graph = rustworkx.generators.directed_path_graph(4) + with self.assertRaises(rustworkx.FailedToConverge): + rustworkx.hits(graph, max_iter=0) + + def test_normalized(self): + graph = rustworkx.generators.directed_complete_graph(2) + h, a = rustworkx.hits(graph, normalized=False) + self.assertEqual({0: 1, 1: 1}, h) + self.assertEqual({0: 1, 1: 1}, a) + + def test_multi_digraph_versus_weighted(self): + multi_graph = rustworkx.PyDiGraph() + multi_graph.extend_from_edge_list( + [ + (0, 1), + (1, 0), + (0, 1), + (1, 0), + (0, 1), + (1, 0), + (1, 2), + (2, 1), + (1, 2), + (2, 1), + (2, 3), + (3, 2), + (2, 3), + (3, 2), + ] + ) + + weighted_graph = rustworkx.PyDiGraph() + weighted_graph.extend_from_weighted_edge_list( + [(0, 1, 3), (1, 0, 3), (1, 2, 2), (2, 1, 2), (2, 3, 2), (3, 2, 2)] + ) + + h_multi, a_multi = rustworkx.hits(multi_graph, weight_fn=lambda _: 1.0) + h_weight, a_weight = rustworkx.hits(weighted_graph, weight_fn=float) + + for v in multi_graph.node_indices(): + self.assertAlmostEqual(h_multi[v], h_weight[v], delta=1.0e-4) + self.assertAlmostEqual(a_multi[v], a_weight[v], delta=1.0e-4) + + def test_nstart(self): + graph = rustworkx.generators.directed_complete_graph(10) + nstart = {5: 1, 6: 1} # this guess is worse than the uniform guess =) + h, a = rustworkx.hits(graph, nstart=nstart) + + for v in graph.node_indices(): + self.assertAlmostEqual(h[v], 1 / 10.0, delta=1.0e-4) + self.assertAlmostEqual(a[v], 1 / 10.0, delta=1.0e-4) From 248256341c35cb4df46d74acb4780714d5c89291 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> Date: Thu, 1 Jun 2023 07:49:55 -0400 Subject: [PATCH 22/37] Add Katz Centrality (#797) * Start Katz centrality * More Katz centrality * Almost done 
with Katz * Add release note * Add docs to rustworkx-core * Add documentation and rustworkx-core tests * Fix max iter * Documentation details * Tweak signature * Apply suggestions from code review Co-authored-by: Matthew Treinish * Suggestion from code review * Code review suggestions --------- Co-authored-by: Matthew Treinish --- docs/source/api.rst | 3 + .../notes/add-katz-5389c6e5bd30e176.yaml | 33 +++ rustworkx-core/src/centrality.rs | 241 ++++++++++++++++ rustworkx/__init__.py | 76 +++++ src/centrality.rs | 266 ++++++++++++++++++ src/lib.rs | 2 + .../digraph/test_centrality.py | 46 +++ .../rustworkx_tests/graph/test_centrality.py | 42 +++ 8 files changed, 709 insertions(+) create mode 100644 releasenotes/notes/add-katz-5389c6e5bd30e176.yaml diff --git a/docs/source/api.rst b/docs/source/api.rst index b118a69e7..2eb378e5b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -56,6 +56,7 @@ Centrality rustworkx.betweenness_centrality rustworkx.edge_betweenness_centrality rustworkx.eigenvector_centrality + rustworkx.katz_centrality rustworkx.closeness_centrality .. _link-analysis: @@ -338,6 +339,7 @@ the functions from the explicitly typed based on the data type. rustworkx.digraph_edge_betweenness_centrality rustworkx.digraph_closeness_centrality rustworkx.digraph_eigenvector_centrality + rustworkx.digraph_katz_centrality rustworkx.digraph_unweighted_average_shortest_path_length rustworkx.digraph_bfs_search rustworkx.digraph_dijkstra_search @@ -396,6 +398,7 @@ typed API based on the data type. rustworkx.graph_edge_betweenness_centrality rustworkx.graph_closeness_centrality rustworkx.graph_eigenvector_centrality + rustworkx.graph_katz_centrality rustworkx.graph_unweighted_average_shortest_path_length rustworkx.graph_bfs_search rustworkx.graph_dijkstra_search diff --git a/releasenotes/notes/add-katz-5389c6e5bd30e176.yaml b/releasenotes/notes/add-katz-5389c6e5bd30e176.yaml new file mode 100644 index 000000000..a10e4dc7c --- /dev/null +++ b/releasenotes/notes/add-katz-5389c6e5bd30e176.yaml @@ -0,0 +1,33 @@ +--- +features: + - | + Added a new function, :func:`~.katz_centrality()` which is used to + compute the Katz centrality for all nodes in a given graph. For + example: + + .. jupyter-execute:: + + import rustworkx as rx + from rustworkx.visualization import mpl_draw + + graph = rx.generators.hexagonal_lattice_graph(4, 4) + centrality = rx.katz_centrality(graph) + + # Generate a color list + colors = [] + for node in graph.node_indices(): + centrality_score = centrality[node] + graph[node] = centrality_score + colors.append(centrality_score) + mpl_draw( + graph, + with_labels=True, + node_color=colors, + node_size=650, + labels=lambda x: "{0:.2f}".format(x) + ) + + - | + Added a new function to rustworkx-core ``katz_centrality`` to the + ``rustworkx_core::centrality`` modules which is used to compute the + Katz centrality for all nodes in a given graph. diff --git a/rustworkx-core/src/centrality.rs b/rustworkx-core/src/centrality.rs index 2866a6d0e..c099c6304 100644 --- a/rustworkx-core/src/centrality.rs +++ b/rustworkx-core/src/centrality.rs @@ -672,6 +672,133 @@ where Ok(None) } +/// Compute the Katz centrality of a graph +/// +/// For details on the Katz centrality refer to: +/// +/// Leo Katz. “A New Status Index Derived from Sociometric Index.” +/// Psychometrika 18(1):39–43, 1953 +/// +/// +/// This function uses a power iteration method to compute the eigenvector +/// and convergence is not guaranteed. 
The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// The implementation of this algorithm is based on the NetworkX +/// [`katz_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html) +/// function. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the eigenvector centrality. +/// +/// Arguments: +/// +/// * `graph` - The graph object to run the algorithm on +/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for +/// an edge in the graph and is expected to return a `Result` of +/// the weight of that edge. +/// * `alpha` - Attenuation factor. If set to `None`, a default value of 0.1 is used. +/// * `beta_map` - Immediate neighbourhood weights. Must contain all node indices or be `None`. +/// * `beta_scalar` - Immediate neighbourhood scalar that replaces `beta_map` in case `beta_map` is None. +/// Defaults to 1.0 in case `None` is provided. +/// * `max_iter` - The maximum number of iterations in the power method. If +/// set to `None` a default value of 100 is used. +/// * `tol` - The error tolerance used when checking for convergence in the +/// power method. If set to `None` a default value of 1e-6 is used. +/// +/// # Example +/// ```rust +/// use rustworkx_core::Result; +/// use rustworkx_core::petgraph; +/// use rustworkx_core::petgraph::visit::{IntoEdges, IntoNodeIdentifiers}; +/// use rustworkx_core::centrality::katz_centrality; +/// +/// let g = petgraph::graph::UnGraph::::from_edges(&[ +/// (0, 1), (1, 2) +/// ]); +/// // Calculate the eigenvector centrality +/// let output: Result>> = katz_centrality(&g, |_| {Ok(1.)}, None, None, None, None, None); +/// let centralities = output.unwrap().unwrap(); +/// assert!(centralities[1] > centralities[0], "Node 1 is more central than node 0"); +/// assert!(centralities[1] > centralities[2], "Node 1 is more central than node 2"); +/// ``` +pub fn katz_centrality( + graph: G, + mut weight_fn: F, + alpha: Option, + beta_map: Option>, + beta_scalar: Option, + max_iter: Option, + tol: Option, +) -> Result>, E> +where + G: NodeIndexable + IntoNodeIdentifiers + IntoNeighbors + IntoEdges + NodeCount, + G::NodeId: Eq + std::hash::Hash, + F: FnMut(G::EdgeRef) -> Result, +{ + let alpha: f64 = alpha.unwrap_or(0.1); + + let mut beta: HashMap = beta_map.unwrap_or_else(HashMap::new); + + if beta.is_empty() { + // beta_map was none + // populate hashmap with default value + let beta_scalar = beta_scalar.unwrap_or(1.0); + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + beta.insert(node, beta_scalar); + } + } else { + // Check if beta contains all node indices + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + if !beta.contains_key(&node) { + return Ok(None); // beta_map was provided but did not include all nodes + } + } + } + + let tol: f64 = tol.unwrap_or(1e-6); + let max_iter = max_iter.unwrap_or(1000); + + let mut x: Vec = vec![0.; graph.node_bound()]; + let node_count = graph.node_count(); + for _ in 0..max_iter { + let x_last = x.clone(); + x = vec![0.; graph.node_bound()]; + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + for edge in graph.edges(node_index) { + let w = weight_fn(edge)?; + let neighbor = edge.target(); + 
x[graph.to_index(neighbor)] += x_last[node] * w; + } + } + for node_index in graph.node_identifiers() { + let node = graph.to_index(node_index); + x[node] = alpha * x[node] + beta.get(&node).unwrap_or(&0.0); + } + if (0..x.len()) + .map(|node| (x[node] - x_last[node]).abs()) + .sum::() + < node_count as f64 * tol + { + // Normalize vector + let norm: f64 = x.iter().map(|val| val.powi(2)).sum::().sqrt(); + if norm == 0. { + return Ok(None); + } + for v in x.iter_mut() { + *v /= norm; + } + + return Ok(Some(x)); + } + } + + Ok(None) +} + #[cfg(test)] mod test_eigenvector_centrality { @@ -761,6 +888,120 @@ mod test_eigenvector_centrality { } } +#[cfg(test)] +mod test_katz_centrality { + + use crate::centrality::katz_centrality; + use crate::petgraph; + use crate::Result; + use hashbrown::HashMap; + + macro_rules! assert_almost_equal { + ($x:expr, $y:expr, $d:expr) => { + if ($x - $y).abs() >= $d { + panic!("{} != {} within delta of {}", $x, $y, $d); + } + }; + } + #[test] + fn test_no_convergence() { + let g = petgraph::graph::UnGraph::::from_edges(&[(0, 1), (1, 2)]); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, None, None, Some(0), None); + let result = output.unwrap(); + assert_eq!(None, result); + } + + #[test] + fn test_incomplete_beta() { + let g = petgraph::graph::UnGraph::::from_edges(&[(0, 1), (1, 2)]); + let beta_map: HashMap = [(0, 1.0)].iter().cloned().collect(); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, Some(beta_map), None, None, None); + let result = output.unwrap(); + assert_eq!(None, result); + } + + #[test] + fn test_complete_beta() { + let g = petgraph::graph::UnGraph::::from_edges(&[(0, 1), (1, 2)]); + let beta_map: HashMap = + [(0, 0.5), (1, 1.0), (2, 0.5)].iter().cloned().collect(); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, Some(beta_map), None, None, None); + let result = output.unwrap().unwrap(); + let expected_values: Vec = + vec![0.4318894504492167, 0.791797325823564, 0.4318894504492167]; + for i in 0..3 { + assert_almost_equal!(expected_values[i], result[i], 1e-4); + } + } + + #[test] + fn test_undirected_complete_graph() { + let g = petgraph::graph::UnGraph::::from_edges([ + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (1, 2), + (1, 3), + (1, 4), + (2, 3), + (2, 4), + (3, 4), + ]); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), Some(0.2), None, Some(1.1), None, None); + let result = output.unwrap().unwrap(); + let expected_value: f64 = (1_f64 / 5_f64).sqrt(); + let expected_values: Vec = vec![expected_value; 5]; + for i in 0..5 { + assert_almost_equal!(expected_values[i], result[i], 1e-4); + } + } + + #[test] + fn test_directed_graph() { + let g = petgraph::graph::DiGraph::::from_edges([ + (0, 1), + (0, 2), + (1, 3), + (2, 1), + (2, 4), + (3, 1), + (3, 4), + (3, 5), + (4, 5), + (4, 6), + (4, 7), + (5, 7), + (6, 0), + (6, 4), + (6, 7), + (7, 5), + (7, 6), + ]); + let output: Result>> = + katz_centrality(&g, |_| Ok(1.), None, None, None, None, None); + let result = output.unwrap().unwrap(); + let expected_values: Vec = vec![ + 0.3135463087489011, + 0.3719056758615039, + 0.3094350787809586, + 0.31527101632646026, + 0.3760169058294464, + 0.38618584417917906, + 0.35465874858087904, + 0.38976653416801743, + ]; + + for i in 0..8 { + assert_almost_equal!(expected_values[i], result[i], 1e-4); + } + } +} + /// Compute the closeness centrality of each node in the graph. 
/// /// The closeness centrality of a node `u` is the reciprocal of the average diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index c1c31750f..fba037d84 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -1776,6 +1776,82 @@ def _graph_eigenvector_centrality( ) +@functools.singledispatch +def katz_centrality( + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=100, tol=1e-6 +): + """Compute the Katz centrality of a graph. + + For details on the Katz centrality refer to: + + Leo Katz. “A New Status Index Derived from Sociometric Index.” + Psychometrika 18(1):39–43, 1953 + + + This function uses a power iteration method to compute the eigenvector + and convergence is not guaranteed. The function will stop when `max_iter` + iterations is reached or when the computed vector between two iterations + is smaller than the error tolerance multiplied by the number of nodes. + The implementation of this algorithm is based on the NetworkX + `katz_centrality() `__ + function. + + In the case of multigraphs the weights of any parallel edges will be + summed when computing the Katz centrality. + + :param graph: Graph to be used. Can either be a + :class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`. + :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. + :param float | dict beta: Immediate neighbourhood weights. If a float is provided, the neighbourhood + weight is used for all nodes. If a dictionary is provided, it must contain all node indices. + If beta is not specified, a default value of 1.0 is used. + :param weight_fn: An optional input callable that will be passed the edge's + payload object and is expected to return a `float` weight for that edge. + If this is not specified ``default_weight`` will be used as the weight + for every edge in ``graph`` + :param float default_weight: If ``weight_fn`` is not set the default weight + value to use for the weight of all edges + :param int max_iter: The maximum number of iterations in the power method. If + not specified a default value of 100 is used. + :param float tol: The error tolerance used when checking for convergence in the + power method. If this is not specified default value of 1e-6 is used. + + :returns: a read-only dict-like object whose keys are the node indices and values are the + centrality score for that node. + :rtype: CentralityMapping + """ + + +@katz_centrality.register(PyDiGraph) +def _digraph_katz_centrality( + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6 +): + return digraph_katz_centrality( + graph, + alpha=alpha, + beta=beta, + weight_fn=weight_fn, + default_weight=default_weight, + max_iter=max_iter, + tol=tol, + ) + + +@katz_centrality.register(PyGraph) +def _graph_katz_centrality( + graph, alpha=0.1, beta=1.0, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6 +): + return graph_katz_centrality( + graph, + alpha=alpha, + beta=beta, + weight_fn=weight_fn, + default_weight=default_weight, + max_iter=max_iter, + tol=tol, + ) + + @functools.singledispatch def vf2_mapping( first, diff --git a/src/centrality.rs b/src/centrality.rs index 553fe9f97..ca055cec3 100644 --- a/src/centrality.rs +++ b/src/centrality.rs @@ -10,6 +10,8 @@ // License for the specific language governing permissions and limitations // under the License. 
+#![allow(clippy::too_many_arguments)] + use std::convert::TryFrom; use crate::digraph; @@ -18,9 +20,12 @@ use crate::iterators::{CentralityMapping, EdgeCentralityMapping}; use crate::CostFn; use crate::FailedToConverge; +use hashbrown::HashMap; use petgraph::graph::NodeIndex; use petgraph::visit::EdgeIndexable; use petgraph::visit::EdgeRef; +use petgraph::visit::IntoNodeIdentifiers; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use rustworkx_core::centrality; @@ -557,3 +562,264 @@ pub fn digraph_eigenvector_centrality( ))), } } + +/// Compute the Katz centrality of a :class:`~PyGraph`. +/// +/// For details on the Katz centrality refer to: +/// +/// Leo Katz. “A New Status Index Derived from Sociometric Index.” +/// Psychometrika 18(1):39–43, 1953 +/// +/// +/// This function uses a power iteration method to compute the eigenvector +/// and convergence is not guaranteed. The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// The implementation of this algorithm is based on the NetworkX +/// `katz_centrality() `__ +/// function. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the Katz centrality. +/// +/// :param PyGraph graph: The graph object to run the algorithm on +/// :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. +/// :param float | dict beta: Immediate neighbourhood weights. If a float is provided, the neighbourhood +/// weight is used for all nodes. If a dictionary is provided, it must contain all node indices. +/// If beta is not specified, a default value of 1.0 is used. +/// :param weight_fn: An optional input callable that will be passed the edge's +/// payload object and is expected to return a `float` weight for that edge. +/// If this is not specified ``default_weight`` will be used as the weight +/// for every edge in ``graph`` +/// :param float default_weight: If ``weight_fn`` is not set the default weight +/// value to use for the weight of all edges +/// :param int max_iter: The maximum number of iterations in the power method. If +/// not specified a default value of 1000 is used. +/// :param float tol: The error tolerance used when checking for convergence in the +/// power method. If this is not specified default value of 1e-6 is used. +/// +/// :returns: a read-only dict-like object whose keys are the node indices and values are the +/// centrality score for that node. 
+/// :rtype: CentralityMapping +#[pyfunction( + signature = ( + graph, + alpha=0.1, + beta=None, + weight_fn=None, + default_weight=1.0, + max_iter=1000, + tol=1e-6 + ) +)] +#[pyo3( + text_signature = "(graph, /, alpha=0.1, beta=None, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6)" +)] +pub fn graph_katz_centrality( + py: Python, + graph: &graph::PyGraph, + alpha: f64, + beta: Option, + weight_fn: Option, + default_weight: f64, + max_iter: usize, + tol: f64, +) -> PyResult { + let mut edge_weights = vec![default_weight; graph.graph.edge_bound()]; + if weight_fn.is_some() { + let cost_fn = CostFn::try_from((weight_fn, default_weight))?; + for edge in graph.graph.edge_indices() { + edge_weights[edge.index()] = + cost_fn.call(py, graph.graph.edge_weight(edge).unwrap())?; + } + } + + let mut beta_map: HashMap = HashMap::new(); + + if let Some(beta) = beta { + match beta.extract::(py) { + Ok(beta_scalar) => { + // User provided a scalar, populate beta_map with the value + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), beta_scalar); + } + } + Err(_) => { + beta_map = beta.extract::>(py)?; + + for node_index in graph.graph.node_identifiers() { + if !beta_map.contains_key(&node_index.index()) { + return Err(PyValueError::new_err( + "Beta does not contain all node indices", + )); + } + } + } + } + } else { + // Populate with 1.0 + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), 1.0); + } + } + + let ev_centrality = centrality::katz_centrality( + &graph.graph, + |e| -> PyResult { Ok(edge_weights[e.id().index()]) }, + Some(alpha), + Some(beta_map), + None, + Some(max_iter), + Some(tol), + )?; + match ev_centrality { + Some(centrality) => Ok(CentralityMapping { + centralities: centrality + .iter() + .enumerate() + .filter_map(|(k, v)| { + if graph.graph.contains_node(NodeIndex::new(k)) { + Some((k, *v)) + } else { + None + } + }) + .collect(), + }), + None => Err(FailedToConverge::new_err(format!( + "Function failed to converge on a solution in {} iterations", + max_iter + ))), + } +} + +/// Compute the Katz centrality of a :class:`~PyDiGraph`. +/// +/// For details on the Katz centrality refer to: +/// +/// Leo Katz. “A New Status Index Derived from Sociometric Index.” +/// Psychometrika 18(1):39–43, 1953 +/// +/// +/// This function uses a power iteration method to compute the eigenvector +/// and convergence is not guaranteed. The function will stop when `max_iter` +/// iterations is reached or when the computed vector between two iterations +/// is smaller than the error tolerance multiplied by the number of nodes. +/// The implementation of this algorithm is based on the NetworkX +/// `katz_centrality() `__ +/// function. +/// +/// In the case of multigraphs the weights of any parallel edges will be +/// summed when computing the Katz centrality. +/// +/// :param PyDiGraph graph: The graph object to run the algorithm on +/// :param float alpha: Attenuation factor. If this is not specified default value of 0.1 is used. +/// :param float | dict beta: Immediate neighbourhood weights. If a float is provided, the neighbourhood +/// weight is used for all nodes. If a dictionary is provided, it must contain all node indices. +/// If beta is not specified, a default value of 1.0 is used. +/// :param weight_fn: An optional input callable that will be passed the edge's +/// payload object and is expected to return a `float` weight for that edge. 
+/// If this is not specified ``default_weight`` will be used as the weight +/// for every edge in ``graph`` +/// :param float default_weight: If ``weight_fn`` is not set the default weight +/// value to use for the weight of all edges +/// :param int max_iter: The maximum number of iterations in the power method. If +/// not specified a default value of 1000 is used. +/// :param float tol: The error tolerance used when checking for convergence in the +/// power method. If this is not specified default value of 1e-6 is used. +/// +/// :returns: a read-only dict-like object whose keys are the node indices and values are the +/// centrality score for that node. +/// :rtype: CentralityMapping +#[pyfunction( + signature = ( + graph, + alpha=0.1, + beta=None, + weight_fn=None, + default_weight=1.0, + max_iter=1000, + tol=1e-6 + ) +)] +#[pyo3( + text_signature = "(graph, /, alpha=0.1, beta=None, weight_fn=None, default_weight=1.0, max_iter=1000, tol=1e-6)" +)] +pub fn digraph_katz_centrality( + py: Python, + graph: &digraph::PyDiGraph, + alpha: f64, + beta: Option, + weight_fn: Option, + default_weight: f64, + max_iter: usize, + tol: f64, +) -> PyResult { + let mut edge_weights = vec![default_weight; graph.graph.edge_bound()]; + if weight_fn.is_some() { + let cost_fn = CostFn::try_from((weight_fn, default_weight))?; + for edge in graph.graph.edge_indices() { + edge_weights[edge.index()] = + cost_fn.call(py, graph.graph.edge_weight(edge).unwrap())?; + } + } + + let mut beta_map: HashMap = HashMap::new(); + + if let Some(beta) = beta { + match beta.extract::(py) { + Ok(beta_scalar) => { + // User provided a scalar, populate beta_map with the value + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), beta_scalar); + } + } + Err(_) => { + beta_map = beta.extract::>(py)?; + + for node_index in graph.graph.node_identifiers() { + if !beta_map.contains_key(&node_index.index()) { + return Err(PyValueError::new_err( + "Beta does not contain all node indices", + )); + } + } + } + } + } else { + // Populate with 1.0 + for node_index in graph.graph.node_identifiers() { + beta_map.insert(node_index.index(), 1.0); + } + } + + let ev_centrality = centrality::katz_centrality( + &graph.graph, + |e| -> PyResult { Ok(edge_weights[e.id().index()]) }, + Some(alpha), + Some(beta_map), + None, + Some(max_iter), + Some(tol), + )?; + + match ev_centrality { + Some(centrality) => Ok(CentralityMapping { + centralities: centrality + .iter() + .enumerate() + .filter_map(|(k, v)| { + if graph.graph.contains_node(NodeIndex::new(k)) { + Some((k, *v)) + } else { + None + } + }) + .collect(), + }), + None => Err(FailedToConverge::new_err(format!( + "Function failed to converge on a solution in {} iterations", + max_iter + ))), + } +} diff --git a/src/lib.rs b/src/lib.rs index af1525515..3eb9ac713 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -428,6 +428,8 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_edge_betweenness_centrality))?; m.add_wrapped(wrap_pyfunction!(graph_eigenvector_centrality))?; m.add_wrapped(wrap_pyfunction!(digraph_eigenvector_centrality))?; + m.add_wrapped(wrap_pyfunction!(graph_katz_centrality))?; + m.add_wrapped(wrap_pyfunction!(digraph_katz_centrality))?; m.add_wrapped(wrap_pyfunction!(graph_astar_shortest_path))?; m.add_wrapped(wrap_pyfunction!(digraph_astar_shortest_path))?; m.add_wrapped(wrap_pyfunction!(graph_greedy_color))?; diff --git a/tests/rustworkx_tests/digraph/test_centrality.py 
b/tests/rustworkx_tests/digraph/test_centrality.py index b97a8527e..3ab12e465 100644 --- a/tests/rustworkx_tests/digraph/test_centrality.py +++ b/tests/rustworkx_tests/digraph/test_centrality.py @@ -14,6 +14,7 @@ import unittest import rustworkx +import networkx as nx class TestCentralityDiGraph(unittest.TestCase): @@ -162,6 +163,51 @@ def test_no_convergence(self): rustworkx.eigenvector_centrality(graph, max_iter=0) +class TestKatzCentrality(unittest.TestCase): + def test_complete_graph(self): + graph = rustworkx.generators.directed_complete_graph(5) + centrality = rustworkx.digraph_katz_centrality(graph) + expected_value = math.sqrt(1.0 / 5.0) + for value in centrality.values(): + self.assertAlmostEqual(value, expected_value, delta=1e-4) + + def test_no_convergence(self): + graph = rustworkx.generators.directed_complete_graph(5) + with self.assertRaises(rustworkx.FailedToConverge): + rustworkx.katz_centrality(graph, max_iter=0) + + def test_beta_scalar(self): + rx_graph = rustworkx.generators.directed_grid_graph(5, 2) + beta = 0.3 + + rx_centrality = rustworkx.katz_centrality(rx_graph, alpha=0.25, beta=beta) + + nx_graph = nx.DiGraph() + nx_graph.add_edges_from(rx_graph.edge_list()) + nx_centrality = nx.katz_centrality(nx_graph, alpha=0.25, beta=beta) + + for key in rx_centrality.keys(): + self.assertAlmostEqual(rx_centrality[key], nx_centrality[key], delta=1e-4) + + def test_beta_dictionary(self): + rx_graph = rustworkx.generators.directed_grid_graph(5, 2) + beta = {i: 0.1 * i**2 for i in range(10)} + + rx_centrality = rustworkx.katz_centrality(rx_graph, alpha=0.25, beta=beta) + + nx_graph = nx.DiGraph() + nx_graph.add_edges_from(rx_graph.edge_list()) + nx_centrality = nx.katz_centrality(nx_graph, alpha=0.25, beta=beta) + + for key in rx_centrality.keys(): + self.assertAlmostEqual(rx_centrality[key], nx_centrality[key], delta=1e-4) + + def test_beta_incomplete(self): + graph = rustworkx.generators.directed_grid_graph(5, 2) + with self.assertRaises(ValueError): + rustworkx.katz_centrality(graph, beta={0: 0.25}) + + class TestEdgeBetweennessCentrality(unittest.TestCase): def test_complete_graph(self): graph = rustworkx.generators.directed_mesh_graph(5) diff --git a/tests/rustworkx_tests/graph/test_centrality.py b/tests/rustworkx_tests/graph/test_centrality.py index 31e88ed7e..12ed67457 100644 --- a/tests/rustworkx_tests/graph/test_centrality.py +++ b/tests/rustworkx_tests/graph/test_centrality.py @@ -14,6 +14,7 @@ import unittest import rustworkx +import networkx as nx class TestCentralityGraph(unittest.TestCase): @@ -133,6 +134,47 @@ def test_no_convergence(self): rustworkx.eigenvector_centrality(graph, max_iter=0) +class TestKatzCentrality(unittest.TestCase): + def test_complete_graph(self): + graph = rustworkx.generators.complete_graph(5) + centrality = rustworkx.graph_katz_centrality(graph) + expected_value = math.sqrt(1.0 / 5.0) + for value in centrality.values(): + self.assertAlmostEqual(value, expected_value, delta=1e-4) + + def test_no_convergence(self): + graph = rustworkx.generators.complete_graph(5) + with self.assertRaises(rustworkx.FailedToConverge): + rustworkx.katz_centrality(graph, max_iter=0) + + def test_beta_scalar(self): + graph = rustworkx.generators.generalized_petersen_graph(5, 2) + expected_value = 0.31622776601683794 + + centrality = rustworkx.katz_centrality(graph, alpha=0.1, beta=0.1, tol=1e-8) + + for value in centrality.values(): + self.assertAlmostEqual(value, expected_value, delta=1e-4) + + def test_beta_dictionary(self): + rx_graph = 
rustworkx.generators.generalized_petersen_graph(5, 2) + beta = {i: 0.1 * i**2 for i in range(10)} + + rx_centrality = rustworkx.katz_centrality(rx_graph, alpha=0.25, beta=beta) + + nx_graph = nx.Graph() + nx_graph.add_edges_from(rx_graph.edge_list()) + nx_centrality = nx.katz_centrality(nx_graph, alpha=0.25, beta=beta) + + for key in rx_centrality.keys(): + self.assertAlmostEqual(rx_centrality[key], nx_centrality[key], delta=1e-4) + + def test_beta_incomplete(self): + graph = rustworkx.generators.generalized_petersen_graph(5, 2) + with self.assertRaises(ValueError): + rustworkx.katz_centrality(graph, beta={0: 0.25}) + + class TestEdgeBetweennessCentrality(unittest.TestCase): def test_complete_graph(self): graph = rustworkx.generators.mesh_graph(5) From 6982d59f914f6a2397a6f54e66ddd802a31e92b5 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 1 Jun 2023 09:52:24 -0400 Subject: [PATCH 23/37] Bump pyo3 and rust numpy version to 0.19.0 (#887) * Bump pyo3 and rust numpy version to 0.19.0 Pyo3 0.19.0 and rust-numpy 0.19.0 were just released. This commit updates the version used in rustworkx to these latest releases. At the same time this updates usage of text signature for classes that was deprecated in 0.19.0 release. While not fatal for normal builds this would have failed clippy in CI because we treat warnings as errors. * Make _dt optional in iterators --------- Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> --- Cargo.lock | 30 +++++++++++++++--------------- Cargo.toml | 4 ++-- rustworkx/iterators.pyi | 3 ++- src/digraph.rs | 3 +-- src/graph.rs | 3 +-- src/toposort.rs | 3 +-- 6 files changed, 22 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e0ed61b78..14bbda764 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -228,9 +228,9 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" dependencies = [ "autocfg", ] @@ -335,9 +335,9 @@ dependencies = [ [[package]] name = "numpy" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b0fee4571867d318651c24f4a570c3f18408cf95f16ccb576b3ce85496a46e" +checksum = "437213adf41bbccf4aeae535fbfcdad0f6fed241e1ae182ebe97fa1f3ce19389" dependencies = [ "libc", "ndarray", @@ -414,16 +414,16 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.18.3" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" +checksum = "cffef52f74ec3b1a1baf295d9b8fcc3070327aefc39a6d00656b13c1d0b8885c" dependencies = [ "cfg-if", "hashbrown", "indexmap", "indoc", "libc", - "memoffset 0.8.0", + "memoffset 0.9.0", "num-bigint", "num-complex 0.4.3", "parking_lot", @@ -435,9 +435,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.18.3" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cb946f5ac61bb61a5014924910d936ebd2b23b705f7a4a3c40b05c720b079a3" +checksum = "713eccf888fb05f1a96eb78c0dbc51907fee42b3377272dc902eb38985f418d5" dependencies = [ "once_cell", "target-lexicon", @@ -445,9 +445,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.18.3" +version = "0.19.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd4d7c5337821916ea2a1d21d1092e8443cf34879e53a0ac653fbb98f44ff65c" +checksum = "5b2ecbdcfb01cbbf56e179ce969a048fd7305a66d4cdf3303e0da09d69afe4c3" dependencies = [ "libc", "pyo3-build-config", @@ -455,9 +455,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.18.3" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d39c55dab3fc5a4b25bbd1ac10a2da452c4aca13bb450f22818a002e29648d" +checksum = "b78fdc0899f2ea781c463679b20cb08af9247febc8d052de941951024cd8aea0" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -467,9 +467,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.18.3" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918" +checksum = "60da7b84f1227c3e2fe7593505de274dcf4c8928b4e0a1c23d551a14e4e80a0f" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 0a7fafa3d..7894eecfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ crate-type = ["cdylib"] ahash = "0.8.0" petgraph = "0.6.3" fixedbitset = "0.4.2" -numpy = "0.18.0" +numpy = "0.19.0" rand = "0.8" rand_pcg = "0.3" rayon = "1.6" @@ -36,7 +36,7 @@ serde_json = "1.0" rustworkx-core = { path = "rustworkx-core", version = "=0.13.0" } [dependencies.pyo3] -version = "0.18.3" +version = "0.19.0" features = ["extension-module", "hashbrown", "num-bigint", "num-complex", "indexmap"] [dependencies.hashbrown] diff --git a/rustworkx/iterators.pyi b/rustworkx/iterators.pyi index 81840d84b..e0d32c2f8 100644 --- a/rustworkx/iterators.pyi +++ b/rustworkx/iterators.pyi @@ -23,6 +23,7 @@ from typing import ( Tuple, overload, final, + Optional, ) from abc import ABC from collections.abc import Sequence @@ -68,7 +69,7 @@ class RustworkxCustomVecIter(Generic[T_co], Sequence[T_co], ABC): def __len__(self) -> int: ... def __ne__(self, other: object) -> bool: ... def __setstate__(self, state: Sequence[T_co]) -> None: ... - def __array__(self, _dt: np.dtype = ...) -> np.ndarray: ... + def __array__(self, _dt: Optional[np.dtype] = ...) -> np.ndarray: ... class RustworkxCustomHashMapIter(Generic[S, T_co], Mapping[S, T_co], ABC): def __init__(self) -> None: ... diff --git a/src/digraph.rs b/src/digraph.rs index a0d63b09c..353da8504 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -173,7 +173,6 @@ use super::dag_algo::is_directed_acyclic_graph; /// :attr:`~.PyDiGraph.attrs` attribute. This can be any Python object. If /// it is not specified :attr:`~.PyDiGraph.attrs` will be set to ``None``. #[pyclass(mapping, module = "rustworkx", subclass)] -#[pyo3(text_signature = "(/, check_cycle=False, multigraph=True, attrs=None)")] #[derive(Clone)] pub struct PyDiGraph { pub graph: StablePyGraph, @@ -285,7 +284,7 @@ impl PyDiGraph { #[pymethods] impl PyDiGraph { #[new] - #[pyo3(signature=(check_cycle=false, multigraph=true, attrs=None))] + #[pyo3(signature=(check_cycle=false, multigraph=true, attrs=None), text_signature="(/, check_cycle=False, multigraph=True, attrs=None)")] fn new(py: Python, check_cycle: bool, multigraph: bool, attrs: Option) -> Self { PyDiGraph { graph: StablePyGraph::::new(), diff --git a/src/graph.rs b/src/graph.rs index d57fbb321..75165cc4c 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -136,7 +136,6 @@ use petgraph::visit::{ /// :attr:`~.PyGraph.attrs` attribute. This can be any Python object. 
If /// it is not specified :attr:`~.PyGraph.attrs` will be set to ``None``. #[pyclass(mapping, module = "rustworkx", subclass)] -#[pyo3(text_signature = "(/, multigraph=True, attrs=None)")] #[derive(Clone)] pub struct PyGraph { pub graph: StablePyGraph, @@ -181,7 +180,7 @@ impl PyGraph { #[pymethods] impl PyGraph { #[new] - #[pyo3(signature=(multigraph=true, attrs=None))] + #[pyo3(signature=(multigraph=true, attrs=None), text_signature = "(/, multigraph=True, attrs=None)")] fn new(py: Python, multigraph: bool, attrs: Option) -> Self { PyGraph { graph: StablePyGraph::::default(), diff --git a/src/toposort.rs b/src/toposort.rs index 7dcb1932a..e91b45314 100644 --- a/src/toposort.rs +++ b/src/toposort.rs @@ -62,7 +62,6 @@ enum NodeState { /// it's set to ``False``, topological sorter will output as many nodes /// as possible until cycles block more progress. By default is ``True``. #[pyclass(module = "rustworkx")] -#[pyo3(text_signature = "(graph, /, check_cycle=True)")] pub struct TopologicalSorter { dag: Py, ready_nodes: Vec, @@ -75,7 +74,7 @@ pub struct TopologicalSorter { #[pymethods] impl TopologicalSorter { #[new] - #[pyo3(signature=(dag, check_cycle=true))] + #[pyo3(signature=(dag, check_cycle=true), text_signature = "(graph, /, check_cycle=True)")] fn new(py: Python, dag: Py, check_cycle: bool) -> PyResult { { let dag = &dag.borrow(py); From 94d5673c672da568bfb6ea501212f785aac8ab26 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Mon, 5 Jun 2023 08:30:51 -0400 Subject: [PATCH 24/37] Prepare 0.13.0 release (#881) * Prepare 0.13.0 release This commit prepares the 0.13.0 release. It moves the release notes that for the release to a separate directory and adds a prelude to the release notes. This should be the last commit to merge and after it merges it'll that commit will be tagged as the 0.13.0 release. 
* Update release notes * More release note updates --- releasenotes/config.yaml | 1 + ...graph-make-symmetric-60d0287a7f7eec04.yaml | 0 ...tweenness-centrality-8de06bf716caece0.yaml | 2 +- ...ycle-and-cycle-basis-399d487def06239d.yaml | 0 .../{ => 0.13}/add-hits-dec9da09240e8787.yaml | 0 .../{ => 0.13}/add-katz-5389c6e5bd30e176.yaml | 0 ...-longest-simple-path-afdc4538c49bc38f.yaml | 0 .../add-pagerank-bef0de7d46026071.yaml | 0 ...dd-random-generators-9f99e57b5e4188f2.yaml | 0 ...add_bfs_predecessors-751b468d1b3c01de.yaml | 0 ...ecessor_node_by_edge-fcd525aff055c5fb.yaml | 0 .../added-token-swapper-bd168eeb5a31bd99.yaml | 21 +++++++++++++++++++ .../bump-msrv-4c746d8fbd91ea7c.yaml | 2 +- ...closeness-centrality-459c5c7e35cb2e63.yaml | 0 .../empty-and-complete-09d569b42cb6b9d5.yaml | 0 ...-check-cycle-on-copy-d060a1781976f728.yaml | 7 +++++++ ...-edge-indices-pickle-83fddf149441fa9f.yaml | 0 ...riority-queue-compat-48c22f64a9208812.yaml | 0 ...ix-sequence-protocol-e95246e864cc850a.yaml | 0 ...odule-rustworkx-core-c314df0b8a42aab4.yaml | 0 .../graph-annotations-1d436930bf60c5c2.yaml | 7 +++++++ ...le-non-existent-edge-15d70cfe60c89ac2.yaml | 0 ...migrate-greedy-color-c3239f35840eec18.yaml | 0 .../move-core-number-48efa7ef968b6736.yaml | 2 +- .../0.13/prepare-0.13.0-5e579fb3ab1e3b60.yaml | 17 +++++++++++++++ .../added-token-swapper-bd168eeb5a31bd99.yaml | 6 ------ ...-check-cycle-on-copy-d060a1781976f728.yaml | 5 ----- .../graph-annotations-1d436930bf60c5c2.yaml | 6 ------ 28 files changed, 56 insertions(+), 20 deletions(-) rename releasenotes/notes/{ => 0.13}/add-digraph-make-symmetric-60d0287a7f7eec04.yaml (100%) rename releasenotes/notes/{ => 0.13}/add-edge-betweenness-centrality-8de06bf716caece0.yaml (99%) rename releasenotes/notes/{ => 0.13}/add-find-cycle-and-cycle-basis-399d487def06239d.yaml (100%) rename releasenotes/notes/{ => 0.13}/add-hits-dec9da09240e8787.yaml (100%) rename releasenotes/notes/{ => 0.13}/add-katz-5389c6e5bd30e176.yaml (100%) rename releasenotes/notes/{ => 0.13}/add-longest-simple-path-afdc4538c49bc38f.yaml (100%) rename releasenotes/notes/{ => 0.13}/add-pagerank-bef0de7d46026071.yaml (100%) rename releasenotes/notes/{ => 0.13}/add-random-generators-9f99e57b5e4188f2.yaml (100%) rename releasenotes/notes/{ => 0.13}/add_bfs_predecessors-751b468d1b3c01de.yaml (100%) rename releasenotes/notes/{ => 0.13}/add_find_predecessor_node_by_edge-fcd525aff055c5fb.yaml (100%) create mode 100644 releasenotes/notes/0.13/added-token-swapper-bd168eeb5a31bd99.yaml rename releasenotes/notes/{ => 0.13}/bump-msrv-4c746d8fbd91ea7c.yaml (74%) rename releasenotes/notes/{ => 0.13}/closeness-centrality-459c5c7e35cb2e63.yaml (100%) rename releasenotes/notes/{ => 0.13}/empty-and-complete-09d569b42cb6b9d5.yaml (100%) create mode 100644 releasenotes/notes/0.13/fix-check-cycle-on-copy-d060a1781976f728.yaml rename releasenotes/notes/{ => 0.13}/fix-edge-indices-pickle-83fddf149441fa9f.yaml (100%) rename releasenotes/notes/{ => 0.13}/fix-priority-queue-compat-48c22f64a9208812.yaml (100%) rename releasenotes/notes/{ => 0.13}/fix-sequence-protocol-e95246e864cc850a.yaml (100%) rename releasenotes/notes/{ => 0.13}/generators-module-rustworkx-core-c314df0b8a42aab4.yaml (100%) create mode 100644 releasenotes/notes/0.13/graph-annotations-1d436930bf60c5c2.yaml rename releasenotes/notes/{ => 0.13}/handle-non-existent-edge-15d70cfe60c89ac2.yaml (100%) rename releasenotes/notes/{ => 0.13}/migrate-greedy-color-c3239f35840eec18.yaml (100%) rename releasenotes/notes/{ => 0.13}/move-core-number-48efa7ef968b6736.yaml 
(62%) create mode 100644 releasenotes/notes/0.13/prepare-0.13.0-5e579fb3ab1e3b60.yaml delete mode 100644 releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml delete mode 100644 releasenotes/notes/fix-check-cycle-on-copy-d060a1781976f728.yaml delete mode 100644 releasenotes/notes/graph-annotations-1d436930bf60c5c2.yaml diff --git a/releasenotes/config.yaml b/releasenotes/config.yaml index 141d552cd..0763ea8f8 100644 --- a/releasenotes/config.yaml +++ b/releasenotes/config.yaml @@ -1,3 +1,4 @@ --- encoding: utf8 default_branch: main +earliest_version: 0.8.0 diff --git a/releasenotes/notes/add-digraph-make-symmetric-60d0287a7f7eec04.yaml b/releasenotes/notes/0.13/add-digraph-make-symmetric-60d0287a7f7eec04.yaml similarity index 100% rename from releasenotes/notes/add-digraph-make-symmetric-60d0287a7f7eec04.yaml rename to releasenotes/notes/0.13/add-digraph-make-symmetric-60d0287a7f7eec04.yaml diff --git a/releasenotes/notes/add-edge-betweenness-centrality-8de06bf716caece0.yaml b/releasenotes/notes/0.13/add-edge-betweenness-centrality-8de06bf716caece0.yaml similarity index 99% rename from releasenotes/notes/add-edge-betweenness-centrality-8de06bf716caece0.yaml rename to releasenotes/notes/0.13/add-edge-betweenness-centrality-8de06bf716caece0.yaml index 81a002d1c..9ae874a42 100644 --- a/releasenotes/notes/add-edge-betweenness-centrality-8de06bf716caece0.yaml +++ b/releasenotes/notes/0.13/add-edge-betweenness-centrality-8de06bf716caece0.yaml @@ -8,7 +8,7 @@ features: Ulrik Brandes, On Variants of Shortest-Path Betweenness Centrality and their Generic Computation. Social Networks 30(2):136-145, 2008. Edge betweenness centrality of an edge :math:`e` is the sum of the - fraction of all-pairs shortest paths that pass through :math`e` + fraction of all-pairs shortest paths that pass through :math:`e` .. 
math:: diff --git a/releasenotes/notes/add-find-cycle-and-cycle-basis-399d487def06239d.yaml b/releasenotes/notes/0.13/add-find-cycle-and-cycle-basis-399d487def06239d.yaml similarity index 100% rename from releasenotes/notes/add-find-cycle-and-cycle-basis-399d487def06239d.yaml rename to releasenotes/notes/0.13/add-find-cycle-and-cycle-basis-399d487def06239d.yaml diff --git a/releasenotes/notes/add-hits-dec9da09240e8787.yaml b/releasenotes/notes/0.13/add-hits-dec9da09240e8787.yaml similarity index 100% rename from releasenotes/notes/add-hits-dec9da09240e8787.yaml rename to releasenotes/notes/0.13/add-hits-dec9da09240e8787.yaml diff --git a/releasenotes/notes/add-katz-5389c6e5bd30e176.yaml b/releasenotes/notes/0.13/add-katz-5389c6e5bd30e176.yaml similarity index 100% rename from releasenotes/notes/add-katz-5389c6e5bd30e176.yaml rename to releasenotes/notes/0.13/add-katz-5389c6e5bd30e176.yaml diff --git a/releasenotes/notes/add-longest-simple-path-afdc4538c49bc38f.yaml b/releasenotes/notes/0.13/add-longest-simple-path-afdc4538c49bc38f.yaml similarity index 100% rename from releasenotes/notes/add-longest-simple-path-afdc4538c49bc38f.yaml rename to releasenotes/notes/0.13/add-longest-simple-path-afdc4538c49bc38f.yaml diff --git a/releasenotes/notes/add-pagerank-bef0de7d46026071.yaml b/releasenotes/notes/0.13/add-pagerank-bef0de7d46026071.yaml similarity index 100% rename from releasenotes/notes/add-pagerank-bef0de7d46026071.yaml rename to releasenotes/notes/0.13/add-pagerank-bef0de7d46026071.yaml diff --git a/releasenotes/notes/add-random-generators-9f99e57b5e4188f2.yaml b/releasenotes/notes/0.13/add-random-generators-9f99e57b5e4188f2.yaml similarity index 100% rename from releasenotes/notes/add-random-generators-9f99e57b5e4188f2.yaml rename to releasenotes/notes/0.13/add-random-generators-9f99e57b5e4188f2.yaml diff --git a/releasenotes/notes/add_bfs_predecessors-751b468d1b3c01de.yaml b/releasenotes/notes/0.13/add_bfs_predecessors-751b468d1b3c01de.yaml similarity index 100% rename from releasenotes/notes/add_bfs_predecessors-751b468d1b3c01de.yaml rename to releasenotes/notes/0.13/add_bfs_predecessors-751b468d1b3c01de.yaml diff --git a/releasenotes/notes/add_find_predecessor_node_by_edge-fcd525aff055c5fb.yaml b/releasenotes/notes/0.13/add_find_predecessor_node_by_edge-fcd525aff055c5fb.yaml similarity index 100% rename from releasenotes/notes/add_find_predecessor_node_by_edge-fcd525aff055c5fb.yaml rename to releasenotes/notes/0.13/add_find_predecessor_node_by_edge-fcd525aff055c5fb.yaml diff --git a/releasenotes/notes/0.13/added-token-swapper-bd168eeb5a31bd99.yaml b/releasenotes/notes/0.13/added-token-swapper-bd168eeb5a31bd99.yaml new file mode 100644 index 000000000..cf3c7abbb --- /dev/null +++ b/releasenotes/notes/0.13/added-token-swapper-bd168eeb5a31bd99.yaml @@ -0,0 +1,21 @@ +--- +features: + - | + Added a new function, :func:`~.graph_token_swapper`, which performs an + approximately optimal token swapping algorithm based on: + + Approximation and Hardness for Token Swapping by Miltzow et al. (2016) + https://arxiv.org/abs/1602.05150 + + that supports partial mappings (i.e. not-permutations) for graphs with + missing tokens. + - | + Added a new function ``token_swapper()`` to the new ``rustworkx-core`` + module ``rustworkx_core::token_swapper``. This function performs an + approximately optimal token swapping algorithm based on: + + Approximation and Hardness for Token Swapping by Miltzow et al. (2016) + https://arxiv.org/abs/1602.05150 + + that supports partial mappings (i.e. 
not-permutations) for graphs with + missing tokens. diff --git a/releasenotes/notes/bump-msrv-4c746d8fbd91ea7c.yaml b/releasenotes/notes/0.13/bump-msrv-4c746d8fbd91ea7c.yaml similarity index 74% rename from releasenotes/notes/bump-msrv-4c746d8fbd91ea7c.yaml rename to releasenotes/notes/0.13/bump-msrv-4c746d8fbd91ea7c.yaml index 2c0cca7b7..8213a77c2 100644 --- a/releasenotes/notes/bump-msrv-4c746d8fbd91ea7c.yaml +++ b/releasenotes/notes/0.13/bump-msrv-4c746d8fbd91ea7c.yaml @@ -3,5 +3,5 @@ upgrade: - | The minimum supported Rust version has been increased from 1.48 to 1.56.1. This applies to both building the rustworkx package from source as well as the - rustworkx-crate. This change was made to facilitate using newer versions of our + rustworkx-core crate. This change was made to facilitate using newer versions of our upstream dependencies as well as leveraging newer Rust language features. diff --git a/releasenotes/notes/closeness-centrality-459c5c7e35cb2e63.yaml b/releasenotes/notes/0.13/closeness-centrality-459c5c7e35cb2e63.yaml similarity index 100% rename from releasenotes/notes/closeness-centrality-459c5c7e35cb2e63.yaml rename to releasenotes/notes/0.13/closeness-centrality-459c5c7e35cb2e63.yaml diff --git a/releasenotes/notes/empty-and-complete-09d569b42cb6b9d5.yaml b/releasenotes/notes/0.13/empty-and-complete-09d569b42cb6b9d5.yaml similarity index 100% rename from releasenotes/notes/empty-and-complete-09d569b42cb6b9d5.yaml rename to releasenotes/notes/0.13/empty-and-complete-09d569b42cb6b9d5.yaml diff --git a/releasenotes/notes/0.13/fix-check-cycle-on-copy-d060a1781976f728.yaml b/releasenotes/notes/0.13/fix-check-cycle-on-copy-d060a1781976f728.yaml new file mode 100644 index 000000000..8c161d97d --- /dev/null +++ b/releasenotes/notes/0.13/fix-check-cycle-on-copy-d060a1781976f728.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixed the :attr:`~.PyDiGraph.check_cycle` attribute not being preserved + when copying :class:`~.PyDiGraph` with ``copy.copy()`` and + ``copy.deepcopy()``. 
+ Fixed `#836 `__ diff --git a/releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml b/releasenotes/notes/0.13/fix-edge-indices-pickle-83fddf149441fa9f.yaml similarity index 100% rename from releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml rename to releasenotes/notes/0.13/fix-edge-indices-pickle-83fddf149441fa9f.yaml diff --git a/releasenotes/notes/fix-priority-queue-compat-48c22f64a9208812.yaml b/releasenotes/notes/0.13/fix-priority-queue-compat-48c22f64a9208812.yaml similarity index 100% rename from releasenotes/notes/fix-priority-queue-compat-48c22f64a9208812.yaml rename to releasenotes/notes/0.13/fix-priority-queue-compat-48c22f64a9208812.yaml diff --git a/releasenotes/notes/fix-sequence-protocol-e95246e864cc850a.yaml b/releasenotes/notes/0.13/fix-sequence-protocol-e95246e864cc850a.yaml similarity index 100% rename from releasenotes/notes/fix-sequence-protocol-e95246e864cc850a.yaml rename to releasenotes/notes/0.13/fix-sequence-protocol-e95246e864cc850a.yaml diff --git a/releasenotes/notes/generators-module-rustworkx-core-c314df0b8a42aab4.yaml b/releasenotes/notes/0.13/generators-module-rustworkx-core-c314df0b8a42aab4.yaml similarity index 100% rename from releasenotes/notes/generators-module-rustworkx-core-c314df0b8a42aab4.yaml rename to releasenotes/notes/0.13/generators-module-rustworkx-core-c314df0b8a42aab4.yaml diff --git a/releasenotes/notes/0.13/graph-annotations-1d436930bf60c5c2.yaml b/releasenotes/notes/0.13/graph-annotations-1d436930bf60c5c2.yaml new file mode 100644 index 000000000..eefdfc25c --- /dev/null +++ b/releasenotes/notes/0.13/graph-annotations-1d436930bf60c5c2.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + Added partial type annotations to the library, including for the + :class:`~retworkx.PyDiGraph` and :class:`~retworkx.PyGraph` classes. + This enables statically type checking with + `mypy `__. diff --git a/releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml b/releasenotes/notes/0.13/handle-non-existent-edge-15d70cfe60c89ac2.yaml similarity index 100% rename from releasenotes/notes/handle-non-existent-edge-15d70cfe60c89ac2.yaml rename to releasenotes/notes/0.13/handle-non-existent-edge-15d70cfe60c89ac2.yaml diff --git a/releasenotes/notes/migrate-greedy-color-c3239f35840eec18.yaml b/releasenotes/notes/0.13/migrate-greedy-color-c3239f35840eec18.yaml similarity index 100% rename from releasenotes/notes/migrate-greedy-color-c3239f35840eec18.yaml rename to releasenotes/notes/0.13/migrate-greedy-color-c3239f35840eec18.yaml diff --git a/releasenotes/notes/move-core-number-48efa7ef968b6736.yaml b/releasenotes/notes/0.13/move-core-number-48efa7ef968b6736.yaml similarity index 62% rename from releasenotes/notes/move-core-number-48efa7ef968b6736.yaml rename to releasenotes/notes/0.13/move-core-number-48efa7ef968b6736.yaml index c98d27573..58fec6662 100644 --- a/releasenotes/notes/move-core-number-48efa7ef968b6736.yaml +++ b/releasenotes/notes/0.13/move-core-number-48efa7ef968b6736.yaml @@ -1,6 +1,6 @@ --- features: - | - The function ``core_number``has been added to the ``rustworkx-core`` + The function ``core_number`` has been added to the ``rustworkx-core`` crate in the ``connectivity`` module. It computes the k-core number for the nodes in a graph. 
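For reference, here is a minimal sketch of the same k-core computation as exercised from the Python level. It assumes the dispatched ``rustworkx.core_number()`` function mirrors the new rustworkx-core implementation (the exact Rust signature is documented in the ``connectivity`` module), so treat the exact return shape as illustrative::

    import rustworkx as rx

    # A triangle (whose nodes form the 2-core) with one pendant node
    # (which only belongs to the 1-core).
    graph = rx.PyGraph()
    a, b, c, d = graph.add_nodes_from([None] * 4)
    graph.add_edges_from_no_data([(a, b), (b, c), (c, a), (c, d)])

    cores = rx.core_number(graph)
    # Expected to look roughly like {0: 2, 1: 2, 2: 2, 3: 1}
    print(cores)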
diff --git a/releasenotes/notes/0.13/prepare-0.13.0-5e579fb3ab1e3b60.yaml b/releasenotes/notes/0.13/prepare-0.13.0-5e579fb3ab1e3b60.yaml new file mode 100644 index 000000000..3104afa19 --- /dev/null +++ b/releasenotes/notes/0.13/prepare-0.13.0-5e579fb3ab1e3b60.yaml @@ -0,0 +1,17 @@ +--- +prelude: | + This release is a major feature release of Rustworkx that adds some + new features to the library. The highlights of this release are: + + * An expansion of the functions exposed by rustworkx-core to include a + new graph generator module. + * New link analysis functions such as page rank + * Expanded centrality measure functions + * Added partial type annotations to the library including for the + :class:`~PyDiGraph` and :class:`~.PyGraph` classes. This enables + type checking with `mypy `__ + + This is also the final rustworkx release that supports running with Python + 3.7. Starting in the 0.14.0 release Python >= 3.8 will be required to use + rustworkx. This release also increased the minimum supported Rust version for + compiling rustworkx and rustworkx-core from source to 1.56.1. diff --git a/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml b/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml deleted file mode 100644 index e26ec104e..000000000 --- a/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -features: - - | - Added a new function, :func:`~.token_swapper()` which performs an - approximately optimal Token Swapping algorithm and supports partial - mappings (i.e. not-permutations) for graphs with missing tokens. diff --git a/releasenotes/notes/fix-check-cycle-on-copy-d060a1781976f728.yaml b/releasenotes/notes/fix-check-cycle-on-copy-d060a1781976f728.yaml deleted file mode 100644 index 228857d5d..000000000 --- a/releasenotes/notes/fix-check-cycle-on-copy-d060a1781976f728.yaml +++ /dev/null @@ -1,5 +0,0 @@ ---- -fixes: - - | - Fixed the :attr:`~.PyDiGraph.check_cycle` attribute not being preserved - when copying :class:`~.PyDiGraph` with `copy.copy()` and `copy.deepcopy()` diff --git a/releasenotes/notes/graph-annotations-1d436930bf60c5c2.yaml b/releasenotes/notes/graph-annotations-1d436930bf60c5c2.yaml deleted file mode 100644 index df7a94409..000000000 --- a/releasenotes/notes/graph-annotations-1d436930bf60c5c2.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -features: - - | - Added type annotations to :class:`~retworkx.PyDiGraph` and - :class:`~retworkx.PyGraph`. They can now be statically type checked - with `mypy `__. From 2f41bd271b2689464fd122fb73e76aadfe4fc5dd Mon Sep 17 00:00:00 2001 From: Kevin Hartman Date: Tue, 6 Jun 2023 13:32:55 -0400 Subject: [PATCH 25/37] Fix typo in PyDiGraph.add_edge docstring. (#890) --- src/digraph.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index 353da8504..a48f2d8b6 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -1065,8 +1065,8 @@ impl PyDiGraph { /// Add an edge between 2 nodes. /// /// Use add_child() or add_parent() to create a node with an edge at the - /// same time as an edge for better performance. Using this method will - /// enable adding duplicate edges between nodes if the ``check_cycle`` + /// same time as an edge for better performance. Using this method + /// allows for adding duplicate edges between nodes if the ``multigraph`` + /// attribute is set to ``True``.
/// /// :param int parent: Index of the parent node From 629d04e9b1e053208a69ddc9f817ecc31fec2fe2 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 6 Jun 2023 23:28:46 -0400 Subject: [PATCH 26/37] Fix pickle/deepcopy node hole handling (#888) * Fix pickle/deepcopy node hole handling This commit fixes an issue introduced by #589 where in certain cases node holes in a graph would result in a panic being raised. This was caused by a logic bug in trying to recreate the holes. Additionally, there were several places where graph methods removed nodes that the flag to indicate there were removals would no be set. This commit fixes all of these issues so that deepcopy/pickle works as expected. * Fix test failures * Fix lint * Update src/digraph.rs --- ...x-removed-nodes-attr-d1829e1f4462d96a.yaml | 7 ++ src/digraph.rs | 83 +++++-------------- src/graph.rs | 78 +++++------------ .../rustworkx_tests/digraph/test_deepcopy.py | 26 ++++++ tests/rustworkx_tests/graph/test_deepcopy.py | 26 ++++++ 5 files changed, 98 insertions(+), 122 deletions(-) create mode 100644 releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml diff --git a/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml b/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml new file mode 100644 index 000000000..7900c2139 --- /dev/null +++ b/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixed an issue with several :class:`~.PyDiGraph` and :class:`~.PyGraph` + methods that removed nodes where previously when calling + these methods the :attr:`.PyDiGraph.node_removed` attribute would not be + updated to reflect that nodes were removed. diff --git a/src/digraph.rs b/src/digraph.rs index a48f2d8b6..6c177e775 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -29,7 +29,7 @@ use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple}; +use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -44,7 +44,7 @@ use petgraph::prelude::*; use petgraph::visit::{ EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, - Visitable, + NodeIndexable, Visitable, }; use super::dot_utils::build_dot; @@ -318,7 +318,6 @@ impl PyDiGraph { }; edges.push(edge); } - let out_dict = PyDict::new(py); let nodes_lst: PyObject = PyList::new(py, nodes).into(); let edges_lst: PyObject = PyList::new(py, edges).into(); @@ -398,55 +397,22 @@ impl PyDiGraph { .downcast::() .unwrap(); - // use a pointer to iter the node list - let mut pointer = 0; - let mut next_node_idx: usize = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - // list of temporary nodes that will be removed later to re-create holes let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); let mut tmp_nodes: Vec = Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); - for i in 0..nodes_lst.len() + 1 { - if i < next_node_idx { + for item in nodes_lst { + let item = item.downcast::().unwrap(); + let next_index: usize = item.get_item(0).unwrap().extract().unwrap(); + let weight: PyObject = item.get_item(1).unwrap().extract().unwrap(); + while next_index > self.graph.node_bound() { // node does not exist let tmp_node = self.graph.add_node(py.None()); 
tmp_nodes.push(tmp_node); - } else { - // add node to the graph, and update the next available node index - let item = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap(); - - let node_w = item.get_item(1).unwrap().extract().unwrap(); - self.graph.add_node(node_w); - pointer += 1; - if pointer < nodes_lst.len() { - next_node_idx = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - } } + // add node to the graph, and update the next available node index + self.graph.add_node(weight); } // Remove any temporary nodes we added for tmp_node in tmp_nodes { @@ -463,20 +429,8 @@ impl PyDiGraph { self.graph.add_edge(tmp_node, tmp_node, py.None()); } else { let triple = item.downcast::().unwrap(); - let edge_p: usize = triple - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - let edge_c: usize = triple - .get_item(1) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); + let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap(); + let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap(); let edge_w = triple.get_item(2).unwrap().extract().unwrap(); self.graph .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); @@ -1760,8 +1714,8 @@ impl PyDiGraph { /// the graph. #[pyo3(text_signature = "(self, index_list, /)")] pub fn remove_nodes_from(&mut self, index_list: Vec) -> PyResult<()> { - for node in index_list.iter().map(|x| NodeIndex::new(*x)) { - self.graph.remove_node(node); + for node in index_list { + self.remove_node(node)?; } Ok(()) } @@ -2389,7 +2343,7 @@ impl PyDiGraph { // If no nodes are copied bail here since there is nothing left // to do. if out_map.is_empty() { - self.graph.remove_node(node_index); + self.remove_node(node_index.index())?; // Return a new empty map to clear allocation from out_map return Ok(NodeMap { node_map: DictMap::new(), @@ -2450,7 +2404,7 @@ impl PyDiGraph { self._add_edge(source_out, target, weight)?; } // Remove node - self.graph.remove_node(node_index); + self.remove_node(node_index.index())?; Ok(NodeMap { node_map: out_map }) } @@ -2559,7 +2513,7 @@ impl PyDiGraph { // Remove nodes that will be replaced. 
for index in indices_to_remove { - self.graph.remove_node(index); + self.remove_node(index.index())?; } // If `weight_combo_fn` was specified, merge edges according @@ -2912,7 +2866,10 @@ impl PyDiGraph { fn __delitem__(&mut self, idx: usize) -> PyResult<()> { match self.graph.remove_node(NodeIndex::new(idx)) { - Some(_) => Ok(()), + Some(_) => { + self.node_removed = true; + Ok(()) + } None => Err(PyIndexError::new_err("No node found for index")), } } diff --git a/src/graph.rs b/src/graph.rs index 75165cc4c..04c90c4c7 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -26,7 +26,7 @@ use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple}; +use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -47,6 +47,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::visit::{ EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, + NodeIndexable, }; /// A class for creating undirected graphs @@ -284,56 +285,24 @@ impl PyGraph { .downcast::() .unwrap(); - // use a pointer to iter the node list - let mut pointer = 0; - let mut next_node_idx: usize = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - // list of temporary nodes that will be removed later to re-create holes let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); let mut tmp_nodes: Vec = Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); - for i in 0..nodes_lst.len() + 1 { - if i < next_node_idx { + for item in nodes_lst { + let item = item.downcast::().unwrap(); + let next_index: usize = item.get_item(0).unwrap().extract().unwrap(); + let weight: PyObject = item.get_item(1).unwrap().extract().unwrap(); + while next_index > self.graph.node_bound() { // node does not exist let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); - } else { - // add node to the graph, and update the next available node index - let item = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap(); - - let node_w = item.get_item(1).unwrap().extract().unwrap(); - self.graph.add_node(node_w); - pointer += 1; - if pointer < nodes_lst.len() { - next_node_idx = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - } } + // add node to the graph, and update the next available node index + self.graph.add_node(weight); } + // Remove any temporary nodes we added for tmp_node in tmp_nodes { self.graph.remove_node(tmp_node); } @@ -348,20 +317,8 @@ impl PyGraph { self.graph.add_edge(tmp_node, tmp_node, py.None()); } else { let triple = item.downcast::().unwrap(); - let edge_p: usize = triple - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - let edge_c: usize = triple - .get_item(1) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); + let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap(); + let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap(); let edge_w = triple.get_item(2).unwrap().extract().unwrap(); self.graph .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); @@ -1062,8 +1019,8 @@ impl PyGraph { /// the graph #[pyo3(text_signature = "(self, index_list, /)")] pub fn 
remove_nodes_from(&mut self, index_list: Vec) -> PyResult<()> { - for node in index_list.iter().map(|x| NodeIndex::new(*x)) { - self.graph.remove_node(node); + for node in index_list { + self.remove_node(node)?; } Ok(()) } @@ -1695,7 +1652,7 @@ impl PyGraph { // Remove nodes that will be replaced. for index in indices_to_remove { - self.graph.remove_node(index); + self.remove_node(index.index())?; } // If `weight_combo_fn` was specified, merge edges according @@ -1846,7 +1803,10 @@ impl PyGraph { fn __delitem__(&mut self, idx: usize) -> PyResult<()> { match self.graph.remove_node(NodeIndex::new(idx)) { - Some(_) => Ok(()), + Some(_) => { + self.node_removed = true; + Ok(()) + } None => Err(PyIndexError::new_err("No node found for index")), } } diff --git a/tests/rustworkx_tests/digraph/test_deepcopy.py b/tests/rustworkx_tests/digraph/test_deepcopy.py index bd296a5a5..542251273 100644 --- a/tests/rustworkx_tests/digraph/test_deepcopy.py +++ b/tests/rustworkx_tests/digraph/test_deepcopy.py @@ -70,3 +70,29 @@ def test_deepcopy_different_objects(self): self.assertIsNot( graph_a.get_edge_data(node_a, node_b), graph_b.get_edge_data(node_a, node_b) ) + + def test_deepcopy_multinode_hole_in_middle(self): + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(range(20)) + graph.remove_nodes_from([10, 11, 12, 13, 14]) + graph.add_edges_from_no_data( + [ + (4, 5), + (16, 18), + (2, 19), + (0, 15), + (15, 16), + (16, 17), + (6, 17), + (8, 18), + (17, 1), + (17, 7), + (18, 3), + (18, 9), + (19, 16), + ] + ) + copied_graph = copy.deepcopy(graph) + self.assertEqual( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19], copied_graph.node_indices() + ) diff --git a/tests/rustworkx_tests/graph/test_deepcopy.py b/tests/rustworkx_tests/graph/test_deepcopy.py index 074d03211..6941d2dcb 100644 --- a/tests/rustworkx_tests/graph/test_deepcopy.py +++ b/tests/rustworkx_tests/graph/test_deepcopy.py @@ -48,3 +48,29 @@ def test_deepcopy_attrs(self): graph = rustworkx.PyGraph(attrs="abc") graph_copy = copy.deepcopy(graph) self.assertEqual(graph.attrs, graph_copy.attrs) + + def test_deepcopy_multinode_hole_in_middle(self): + graph = rustworkx.PyGraph() + graph.add_nodes_from(range(20)) + graph.remove_nodes_from([10, 11, 12, 13, 14]) + graph.add_edges_from_no_data( + [ + (4, 5), + (16, 18), + (2, 19), + (0, 15), + (15, 16), + (16, 17), + (6, 17), + (8, 18), + (17, 1), + (17, 7), + (18, 3), + (18, 9), + (19, 16), + ] + ) + copied_graph = copy.deepcopy(graph) + self.assertEqual( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19], copied_graph.node_indices() + ) From 58fdd933cd81c576a3cc02c8a6b7b4e4646ea9c3 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Wed, 7 Jun 2023 13:12:38 -0400 Subject: [PATCH 27/37] Bump version strings post release (#892) Now that rustworkx 0.13.0 is released this commit bumps all the version strings for the rustworkx and rustworkx-core to be 0.14.0. This now indicates the development version on the main branch is 0.14.0 and differentiates it from the released 0.13.0. 
--- .mergify.yml | 2 +- Cargo.lock | 4 ++-- Cargo.toml | 4 ++-- docs/source/conf.py | 4 ++-- rustworkx-core/Cargo.toml | 2 +- setup.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.mergify.yml b/.mergify.yml index 957529cb3..f599bb38e 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -36,4 +36,4 @@ pull_request_rules: actions: backport: branches: - - stable/0.12 + - stable/0.13 diff --git a/Cargo.lock b/Cargo.lock index 14bbda764..c31ec5e2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -589,7 +589,7 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustworkx" -version = "0.13.0" +version = "0.14.0" dependencies = [ "ahash 0.8.0", "fixedbitset", @@ -615,7 +615,7 @@ dependencies = [ [[package]] name = "rustworkx-core" -version = "0.13.0" +version = "0.14.0" dependencies = [ "ahash 0.8.0", "fixedbitset", diff --git a/Cargo.toml b/Cargo.toml index 7894eecfd..70fd8efd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "rustworkx" description = "A python graph library implemented in Rust" -version = "0.13.0" +version = "0.14.0" authors = ["Matthew Treinish "] license = "Apache-2.0" readme = "README.md" @@ -33,7 +33,7 @@ ndarray-stats = "0.5.1" quick-xml = "0.28" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -rustworkx-core = { path = "rustworkx-core", version = "=0.13.0" } +rustworkx-core = { path = "rustworkx-core", version = "=0.14.0" } [dependencies.pyo3] version = "0.19.0" diff --git a/docs/source/conf.py b/docs/source/conf.py index ddc58fb81..72998d3be 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,9 +24,9 @@ docs_url_prefix = "ecosystem/rustworkx" # The short X.Y version. -version = '0.13.0' +version = '0.14.0' # The full version, including alpha/beta/rc tags. -release = '0.13.0' +release = '0.14.0' extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', diff --git a/rustworkx-core/Cargo.toml b/rustworkx-core/Cargo.toml index 33c7e1200..2d6620082 100644 --- a/rustworkx-core/Cargo.toml +++ b/rustworkx-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustworkx-core" -version = "0.13.0" +version = "0.14.0" edition = "2021" authors = ["Matthew Treinish "] description = "Rust APIs used for rustworkx algorithms" diff --git a/setup.py b/setup.py index 053825cc1..6e8121ce2 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def readme(): graphviz_extras = ['pillow>=5.4'] PKG_NAME = os.getenv('RUSTWORKX_PKG_NAME', "rustworkx") -PKG_VERSION = "0.13.0" +PKG_VERSION = "0.14.0" PKG_PACKAGES = ["rustworkx", "rustworkx.visualization"] PKG_INSTALL_REQUIRES = ['numpy>=1.16.0'] RUST_EXTENSIONS = [RustExtension("rustworkx.rustworkx", "Cargo.toml", From f631184fb14cfcd6581dfc823507a99da4ad31f9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 9 Jun 2023 11:49:12 +0000 Subject: [PATCH 28/37] Bump serde from 1.0.163 to 1.0.164 (#895) Bumps [serde](https://github.com/serde-rs/serde) from 1.0.163 to 1.0.164. - [Release notes](https://github.com/serde-rs/serde/releases) - [Commits](https://github.com/serde-rs/serde/compare/v1.0.163...v1.0.164) --- updated-dependencies: - dependency-name: serde dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c31ec5e2e..f27c61a80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -644,18 +644,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.163" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.163" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", From af4af62a277e6a6a7cdbcd2dd3097e632afca982 Mon Sep 17 00:00:00 2001 From: Edwin Navarro Date: Fri, 9 Jun 2023 12:57:54 -0700 Subject: [PATCH 29/37] Sort the todo_nodes (#897) --- rustworkx-core/src/token_swapper.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rustworkx-core/src/token_swapper.rs b/rustworkx-core/src/token_swapper.rs index 1aa84848d..065c969f8 100644 --- a/rustworkx-core/src/token_swapper.rs +++ b/rustworkx-core/src/token_swapper.rs @@ -23,6 +23,7 @@ use petgraph::visit::{ }; use petgraph::Directed; use petgraph::Direction::{Incoming, Outgoing}; +use rayon::prelude::*; use rayon_cond::CondIterator; use crate::connectivity::find_cycle; @@ -127,10 +128,11 @@ where .collect(); // todo_nodes are all the mapping entries where left != right - let todo_nodes: Vec = tokens + let mut todo_nodes: Vec = tokens .iter() .filter_map(|(node, dest)| if node != dest { Some(*node) } else { None }) .collect(); + todo_nodes.par_sort(); // Add initial edges to the digraph/sub_digraph for node in self.graph.node_identifiers() { From 33c88713ab3dcfca8150b0aa2f12091cdb2a25d1 Mon Sep 17 00:00:00 2001 From: danielleodigie <97267313+danielleodigie@users.noreply.github.com> Date: Sun, 11 Jun 2023 06:10:51 -0400 Subject: [PATCH 30/37] Node and Edge Filtering (#886) Aims to resolve #800 Adds filter_nodes() and filter_edges() methods to the Graph and DiGraph classes * Implementing filter_nodes and filter_edges funcs * Running fmt and clippy * Fixed issue where errors were not being propagated up to Python. Created tests for filter_edges and filter_nodes for both PyGraph and PyDiGraph. Created release notes for the functions. * Ran fmt, clippy, and tox * Fixing release notes * Fixing release notes again * Fixing release notes again again * Fixed release notes * Fixed release notes. Changed Vec allocation. Expanded on documentation. 
* ran cargo fmt and clippy * Fixing docs for filter functions --- .../notes/add-filter-98d00f306b5689ee.yaml | 43 ++++++++++ src/digraph.rs | 78 ++++++++++++++++++ src/graph.rs | 78 ++++++++++++++++++ tests/rustworkx_tests/digraph/test_filter.py | 81 +++++++++++++++++++ tests/rustworkx_tests/graph/test_filter.py | 81 +++++++++++++++++++ 5 files changed, 361 insertions(+) create mode 100644 releasenotes/notes/add-filter-98d00f306b5689ee.yaml create mode 100644 tests/rustworkx_tests/digraph/test_filter.py create mode 100644 tests/rustworkx_tests/graph/test_filter.py diff --git a/releasenotes/notes/add-filter-98d00f306b5689ee.yaml b/releasenotes/notes/add-filter-98d00f306b5689ee.yaml new file mode 100644 index 000000000..c8313e765 --- /dev/null +++ b/releasenotes/notes/add-filter-98d00f306b5689ee.yaml @@ -0,0 +1,43 @@ +--- +features: + - | + The :class:`~rustworkx.PyGraph` and the :class:`~rustworkx.PyDiGraph` classes have a new method + :meth:`~rustworkx.PyGraph.filter_nodes` (or :meth:`~rustworkx.PyDiGraph.filter_nodes`). + This method returns a :class:`~.NodeIndices` object with the resulting nodes that fit some abstract criteria indicated by a filter function. + For example: + + .. jupyter-execute:: + + from rustworkx import PyGraph + + graph = PyGraph() + graph.add_nodes_from(list(range(5))) # Adds nodes from 0 to 5 + + def my_filter_function(node): + return node > 2 + + indices = graph.filter_nodes(my_filter_function) + print(indices) + + - | + The :class:`~rustworkx.PyGraph` and the :class:`~rustworkx.PyDiGraph` classes have a new method + :meth:`~rustworkx.PyGraph.filter_edges` (or :meth:`~rustworkx.PyDiGraph.filter_edges`). + This method returns a :class:`~.EdgeIndices` object with the resulting edges that fit some abstract criteria indicated by a filter function. + For example: + + .. jupyter-execute:: + + from rustworkx import PyGraph + from rustworkx.generators import complete_graph + + graph = PyGraph() + graph.add_nodes_from(range(3)) + graph.add_edges_from([(0, 1, 'A'), (0, 1, 'B'), (1, 2, 'C')]) + + def my_filter_function(edge): + if edge: + return edge == 'B' + return False + + indices = graph.filter_edges(my_filter_function) + print(indices) \ No newline at end of file diff --git a/src/digraph.rs b/src/digraph.rs index 6c177e775..0fe27c4c3 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -2843,6 +2843,84 @@ impl PyDiGraph { } } + /// Filters a graph's nodes by some criteria conditioned on a node's data payload and returns those nodes' indices. + /// + /// This function takes in a function as an argument. This filter function will be passed in a node's data payload and is + /// required to return a boolean value stating whether the node's data payload fits some criteria. 
+ /// + /// For example:: + /// + /// from rustworkx import PyDiGraph + /// + /// graph = PyDiGraph() + /// graph.add_nodes_from(list(range(5))) + /// + /// def my_filter_function(node): + /// return node > 2 + /// + /// indices = graph.filter_nodes(my_filter_function) + /// assert indices == [3, 4] + /// + /// :param filter_function: Function with which to filter nodes + /// :returns: The node indices that match the filter + /// :rtype: NodeIndices + #[pyo3(text_signature = "(self, filter_function)")] + pub fn filter_nodes(&self, py: Python, filter_function: PyObject) -> PyResult { + let filter = |nindex: NodeIndex| -> PyResult { + let res = filter_function.call1(py, (&self.graph[nindex],))?; + res.extract(py) + }; + + let mut n = Vec::with_capacity(self.graph.node_count()); + for node_index in self.graph.node_indices() { + if filter(node_index)? { + n.push(node_index.index()) + }; + } + Ok(NodeIndices { nodes: n }) + } + + /// Filters a graph's edges by some criteria conditioned on a edge's data payload and returns those edges' indices. + /// + /// This function takes in a function as an argument. This filter function will be passed in an edge's data payload and is + /// required to return a boolean value stating whether the edge's data payload fits some criteria. + /// + /// For example:: + /// + /// from rustworkx import PyGraph + /// from rustworkx.generators import complete_graph + /// + /// graph = PyGraph() + /// graph.add_nodes_from(range(3)) + /// graph.add_edges_from([(0, 1, 'A'), (0, 1, 'B'), (1, 2, 'C')]) + /// + /// def my_filter_function(edge): + /// if edge: + /// return edge == 'B' + /// return False + /// + /// indices = graph.filter_edges(my_filter_function) + /// assert indices == [1] + /// + /// :param filter_function: Function with which to filter edges + /// :returns: The edge indices that match the filter + /// :rtype: EdgeIndices + #[pyo3(text_signature = "(self, filter_function)")] + pub fn filter_edges(&self, py: Python, filter_function: PyObject) -> PyResult { + let filter = |eindex: EdgeIndex| -> PyResult { + let res = filter_function.call1(py, (&self.graph[eindex],))?; + res.extract(py) + }; + + let mut e = Vec::with_capacity(self.graph.edge_count()); + for edge_index in self.graph.edge_indices() { + if filter(edge_index)? { + e.push(edge_index.index()) + }; + } + Ok(EdgeIndices { edges: e }) + } + /// Return the number of nodes in the graph fn __len__(&self) -> PyResult { Ok(self.graph.node_count()) diff --git a/src/graph.rs b/src/graph.rs index 04c90c4c7..3907514b1 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -1848,6 +1848,84 @@ impl PyGraph { self.node_removed = false; self.attrs = py.None(); } + + /// Filters a graph's nodes by some criteria conditioned on a node's data payload and returns those nodes' indices. + /// + /// This function takes in a function as an argument. This filter function will be passed in a node's data payload and is + /// required to return a boolean value stating whether the node's data payload fits some criteria. 
+ /// + /// For example:: + /// + /// from rustworkx import PyGraph + /// + /// graph = PyGraph() + /// graph.add_nodes_from(list(range(5))) + /// + /// def my_filter_function(node): + /// return node > 2 + /// + /// indices = graph.filter_nodes(my_filter_function) + /// assert indices == [3, 4] + /// + /// :param filter_function: Function with which to filter nodes + /// :returns: The node indices that match the filter + /// :rtype: NodeIndices + #[pyo3(text_signature = "(self, filter_function)")] + pub fn filter_nodes(&self, py: Python, filter_function: PyObject) -> PyResult { + let filter = |nindex: NodeIndex| -> PyResult { + let res = filter_function.call1(py, (&self.graph[nindex],))?; + res.extract(py) + }; + + let mut n = Vec::with_capacity(self.graph.node_count()); + for node_index in self.graph.node_indices() { + if filter(node_index)? { + n.push(node_index.index()) + }; + } + Ok(NodeIndices { nodes: n }) + } + + /// Filters a graph's edges by some criteria conditioned on a edge's data payload and returns those edges' indices. + /// + /// This function takes in a function as an argument. This filter function will be passed in an edge's data payload and is + /// required to return a boolean value stating whether the edge's data payload fits some criteria. + /// + /// For example:: + /// + /// from rustworkx import PyGraph + /// from rustworkx.generators import complete_graph + /// + /// graph = PyGraph() + /// graph.add_nodes_from(range(3)) + /// graph.add_edges_from([(0, 1, 'A'), (0, 1, 'B'), (1, 2, 'C')]) + /// + /// def my_filter_function(edge): + /// if edge: + /// return edge == 'B' + /// return False + /// + /// indices = graph.filter_edges(my_filter_function) + /// assert indices == [1] + /// + /// :param filter_function: Function with which to filter edges + /// :returns: The edge indices that match the filter + /// :rtype: EdgeIndices + #[pyo3(text_signature = "(self, filter_function)")] + pub fn filter_edges(&self, py: Python, filter_function: PyObject) -> PyResult { + let filter = |eindex: EdgeIndex| -> PyResult { + let res = filter_function.call1(py, (&self.graph[eindex],))?; + res.extract(py) + }; + + let mut e = Vec::with_capacity(self.graph.edge_count()); + for edge_index in self.graph.edge_indices() { + if filter(edge_index)? { + e.push(edge_index.index()) + }; + } + Ok(EdgeIndices { edges: e }) + } } fn weight_transform_callable( diff --git a/tests/rustworkx_tests/digraph/test_filter.py b/tests/rustworkx_tests/digraph/test_filter.py new file mode 100644 index 000000000..2593068b2 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_filter.py @@ -0,0 +1,81 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import unittest + +import rustworkx as rx + + +class TestFilter(unittest.TestCase): + def test_filter_nodes(self): + def my_filter_function1(node): + return node == "cat" + + def my_filter_function2(node): + return node == "lizard" + + def my_filter_function3(node): + return node == "human" + + graph = rx.PyDiGraph() + graph.add_node("cat") + graph.add_node("cat") + graph.add_node("dog") + graph.add_node("lizard") + graph.add_node("cat") + cat_indices = graph.filter_nodes(my_filter_function1) + lizard_indices = graph.filter_nodes(my_filter_function2) + human_indices = graph.filter_nodes(my_filter_function3) + self.assertEqual(list(cat_indices), [0, 1, 4]) + self.assertEqual(list(lizard_indices), [3]) + self.assertEqual(list(human_indices), []) + + def test_filter_edges(self): + def my_filter_function1(edge): + return edge == "friends" + + def my_filter_function2(edge): + return edge == "enemies" + + def my_filter_function3(node): + return node == "frenemies" + + graph = rx.PyDiGraph() + graph.add_node("cat") + graph.add_node("cat") + graph.add_node("dog") + graph.add_node("lizard") + graph.add_node("cat") + graph.add_edge(0, 2, "friends") + graph.add_edge(0, 1, "friends") + graph.add_edge(0, 3, "enemies") + friends_indices = graph.filter_edges(my_filter_function1) + enemies_indices = graph.filter_edges(my_filter_function2) + frenemies_indices = graph.filter_edges(my_filter_function3) + self.assertEqual(list(friends_indices), [0, 1]) + self.assertEqual(list(enemies_indices), [2]) + self.assertEqual(list(frenemies_indices), []) + + def test_filter_errors(self): + def my_filter_function1(node): + raise TypeError("error!") + + graph = rx.PyDiGraph() + graph.add_node("cat") + graph.add_node("cat") + graph.add_node("dog") + graph.add_edge(0, 1, "friends") + graph.add_edge(1, 2, "enemies") + with self.assertRaises(TypeError): + graph.filter_nodes(my_filter_function1) + with self.assertRaises(TypeError): + graph.filter_edges(my_filter_function1) diff --git a/tests/rustworkx_tests/graph/test_filter.py b/tests/rustworkx_tests/graph/test_filter.py new file mode 100644 index 000000000..9edb70d99 --- /dev/null +++ b/tests/rustworkx_tests/graph/test_filter.py @@ -0,0 +1,81 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import unittest + +import rustworkx as rx + + +class TestFilter(unittest.TestCase): + def test_filter_nodes(self): + def my_filter_function1(node): + return node == "cat" + + def my_filter_function2(node): + return node == "lizard" + + def my_filter_function3(node): + return node == "human" + + graph = rx.PyGraph() + graph.add_node("cat") + graph.add_node("cat") + graph.add_node("dog") + graph.add_node("lizard") + graph.add_node("cat") + cat_indices = graph.filter_nodes(my_filter_function1) + lizard_indices = graph.filter_nodes(my_filter_function2) + human_indices = graph.filter_nodes(my_filter_function3) + self.assertEqual(list(cat_indices), [0, 1, 4]) + self.assertEqual(list(lizard_indices), [3]) + self.assertEqual(list(human_indices), []) + + def test_filter_edges(self): + def my_filter_function1(edge): + return edge == "friends" + + def my_filter_function2(edge): + return edge == "enemies" + + def my_filter_function3(node): + return node == "frenemies" + + graph = rx.PyGraph() + graph.add_node("cat") + graph.add_node("cat") + graph.add_node("dog") + graph.add_node("lizard") + graph.add_node("cat") + graph.add_edge(0, 2, "friends") + graph.add_edge(0, 1, "friends") + graph.add_edge(0, 3, "enemies") + friends_indices = graph.filter_edges(my_filter_function1) + enemies_indices = graph.filter_edges(my_filter_function2) + frenemies_indices = graph.filter_edges(my_filter_function3) + self.assertEqual(list(friends_indices), [0, 1]) + self.assertEqual(list(enemies_indices), [2]) + self.assertEqual(list(frenemies_indices), []) + + def test_filter_errors(self): + def my_filter_function1(node): + raise TypeError("error!") + + graph = rx.PyGraph() + graph.add_node("cat") + graph.add_node("cat") + graph.add_node("dog") + graph.add_edge(0, 1, "friends") + graph.add_edge(1, 2, "enemies") + with self.assertRaises(TypeError): + graph.filter_nodes(my_filter_function1) + with self.assertRaises(TypeError): + graph.filter_edges(my_filter_function1) From 36d0921b006650a871b3494568cc3c14943ee81f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 13 Jun 2023 12:07:58 -0400 Subject: [PATCH 31/37] Bump quick-xml from 0.28.2 to 0.29.0 (#900) Bumps [quick-xml](https://github.com/tafia/quick-xml) from 0.28.2 to 0.29.0. - [Release notes](https://github.com/tafia/quick-xml/releases) - [Changelog](https://github.com/tafia/quick-xml/blob/master/Changelog.md) - [Commits](https://github.com/tafia/quick-xml/compare/v0.28.2...v0.29.0) --- updated-dependencies: - dependency-name: quick-xml dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 Cargo.lock | 4 ++--
 Cargo.toml | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index f27c61a80..92ad278c3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -478,9 +478,9 @@ dependencies = [
 
 [[package]]
 name = "quick-xml"
-version = "0.28.2"
+version = "0.29.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ce5e73202a820a31f8a0ee32ada5e21029c81fd9e3ebf668a40832e4219d9d1"
+checksum = "81b9228215d82c7b61490fec1de287136b5de6f5700f6e58ea9ad61a7964ca51"
 dependencies = [
  "memchr",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 70fd8efd8..94873ca2b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,7 +30,7 @@ num-traits = "0.2"
 num-bigint = "0.4"
 num-complex = "0.4"
 ndarray-stats = "0.5.1"
-quick-xml = "0.28"
+quick-xml = "0.29"
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 rustworkx-core = { path = "rustworkx-core", version = "=0.14.0" }

From 8fe10cbafcae7af96e8f1a76a8719e108e3d307d Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Fri, 16 Jun 2023 11:23:16 -0400
Subject: [PATCH 32/37] Bump serde_json from 1.0.96 to 1.0.97 (#903)

Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.96 to 1.0.97.
- [Release notes](https://github.com/serde-rs/json/releases)
- [Commits](https://github.com/serde-rs/json/compare/v1.0.96...v1.0.97)

---
updated-dependencies:
- dependency-name: serde_json
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 Cargo.lock | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 92ad278c3..ceae8f2b8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -664,9 +664,9 @@ dependencies = [
 
 [[package]]
 name = "serde_json"
-version = "1.0.96"
+version = "1.0.97"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
+checksum = "bdf3bf93142acad5821c99197022e170842cdbc1c30482b98750c688c640842a"
 dependencies = [
  "itoa",
  "ryu",

From 8720cd3326f1c3375410ddb99c9a5916432b51df Mon Sep 17 00:00:00 2001
From: Matthew Treinish
Date: Fri, 16 Jun 2023 17:27:18 -0400
Subject: [PATCH 33/37] Apply wheel job fixes from 0.13.0 release (#896)

* Apply wheel job fixes from 0.13.0 release

This commit applies the various fixes needed to the wheel publishing job
definitions for the 0.13.0 release. There were several job errors during
the release, caused by various changes made to rustworkx, upstream
dependencies, and the CI environment since the 0.12.0 release. #753 should
still be finished to simplify the job definitions, but that should be
rebased on top of this more targeted fix. The intent is for this to be a
minimal diff for backporting to stable/0.13 for a future 0.13.1 release.

Of particular importance here is the change in support tier for s390x from
3 to 4. This was caused by repeated timeouts when running tests during the
s390x Linux wheel builds. To ensure we can reliably build the wheels, this
drops the testing from the s390x jobs so that they can reliably complete
in 12 job hours.
* Fix docs * Update releasenotes/notes/s390x-tier-4-1701a0f044759cd1.yaml * Update releasenotes/notes/s390x-tier-4-1701a0f044759cd1.yaml Co-authored-by: Edwin Navarro --------- Co-authored-by: Edwin Navarro --- .github/workflows/wheels.yml | 22 ++++++++++--------- docs/source/install.rst | 21 ++++++++++-------- .../notes/s390x-tier-4-1701a0f044759cd1.yaml | 12 ++++++++++ 3 files changed, 36 insertions(+), 19 deletions(-) create mode 100644 releasenotes/notes/s390x-tier-4-1701a0f044759cd1.yaml diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index b94893ecd..6780979d4 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -67,7 +67,7 @@ jobs: CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest CIBW_SKIP: cp36-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust - CIBW_TEST_REQUIRES: networkx scipy testtools fixtures + CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests - uses: actions/upload-artifact@v3 with: @@ -92,7 +92,7 @@ jobs: python-version: '3.7' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 with: platforms: all - name: Install cibuildwheel @@ -135,7 +135,7 @@ jobs: python-version: '3.7' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 with: platforms: all - name: Install cibuildwheel @@ -150,7 +150,7 @@ jobs: CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.9 CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64:latest CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest - CIBW_SKIP: cp36-* cp39-* cp310-* pp* *win32 *musl* + CIBW_SKIP: cp36-* cp39-* cp310-* cp311-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests @@ -178,7 +178,7 @@ jobs: python-version: '3.7' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 with: platforms: all - name: Install cibuildwheel @@ -221,7 +221,7 @@ jobs: python-version: '3.7' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 with: platforms: all - name: Install cibuildwheel @@ -236,11 +236,12 @@ jobs: CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.9 CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64:latest CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest - CIBW_SKIP: cp36-* cp39-* cp310-* pp* *win32 *musl* + CIBW_SKIP: cp36-* cp39-* cp310-* cp311-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests CIBW_ARCHS_LINUX: s390x + CIBW_TEST_SKIP: "*-*linux_s390x" - uses: actions/upload-artifact@v3 with: path: ./wheelhouse/*.whl @@ -264,7 +265,7 @@ jobs: python-version: '3.7' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 with: platforms: all - name: Install cibuildwheel @@ -284,6 +285,7 @@ jobs: CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests CIBW_ARCHS_LINUX: s390x + CIBW_TEST_SKIP: 
"*-*linux_s390x" - uses: actions/upload-artifact@v3 with: path: ./wheelhouse/*.whl @@ -294,7 +296,7 @@ jobs: TWINE_USERNAME: retworkx-ci build-mac-arm-wheels: name: Build wheels on macos for arm and universal2 - runs-on: macos-10.15 + runs-on: macos-latest steps: - uses: actions/checkout@v3 - name: Build wheels @@ -365,7 +367,7 @@ jobs: with: python-version: '3.10' - name: Install deps - run: pip install -U twine setuptools-rust + run: pip install -U twine setuptools-rust wheel build - name: Build sdist run: python setup.py bdist_wheel env: diff --git a/docs/source/install.rst b/docs/source/install.rst index ed2964696..1f1ed32a3 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -10,8 +10,8 @@ Installing Rustworkx ==================== rustworkx is published on pypi so on x86_64, i686, ppc64le, s390x, and aarch64 -Linux systems, x86_64 on Mac OSX, and 32 and 64 bit Windows installing is as -simple as running:: +Linux systems, x86_64 and arm64 on macOS, and 32 and 64 bit Windows +installing is as simple as running:: pip install rustworkx @@ -71,28 +71,28 @@ source. * - Linux - x86_64 - :ref:`tier-1` - - Distributions compatible with the [manylinux 2014](https://peps.python.org/pep-0599/) packaging specification + - Distributions compatible with the `manylinux 2014`_ packaging specification * - Linux - i686 - :ref:`tier-2` (Python < 3.10), :ref:`tier-3` (Python >= 3.10) - - Distributions compatible with the [manylinux 2014](https://peps.python.org/pep-0599/) packaging specification + - Distributions compatible with the `manylinux 2014`_ packaging specification * - Linux - aarch64 - :ref:`tier-2` - - Distributions compatible with the [manylinux 2014](https://peps.python.org/pep-0599/) packaging specification + - Distributions compatible with the `manylinux 2014`_ packaging specification * - Linux - pp64le - :ref:`tier-3` - - Distributions compatible with the [manylinux 2014](https://peps.python.org/pep-0599/) packaging specification + - Distributions compatible with the `manylinux 2014`_ packaging specification * - Linux - s390x - - :ref:`tier-3` - - Distributions compatible with the [manylinux 2014](https://peps.python.org/pep-0599/) packaging specification + - :ref:`tier-4` + - Distributions compatible with the `manylinux 2014`_ packaging specification * - macOS (10.9 or newer) - x86_64 - :ref:`tier-1` - - * - macOS (10.15 or newer) + * - macOS (11 or newer) - arm64 - :ref:`tier-4` - @@ -105,6 +105,9 @@ source. - :ref:`tier-2` (Python < 3.10), :ref:`tier-3` (Python >= 3.10) - + +.. _manylinux 2014: https://peps.python.org/pep-0599/> + .. _tier-1: Tier 1 diff --git a/releasenotes/notes/s390x-tier-4-1701a0f044759cd1.yaml b/releasenotes/notes/s390x-tier-4-1701a0f044759cd1.yaml new file mode 100644 index 000000000..b4f3b3f9a --- /dev/null +++ b/releasenotes/notes/s390x-tier-4-1701a0f044759cd1.yaml @@ -0,0 +1,12 @@ +--- +upgrade: + - | + Support for the Linux s390x platform has changed from tier 3 to tier 4 (as + documented in :ref:`platform-suppport`). This is a result of no longer being + able to run tests during the pre-compiled wheel publishing jobs due to + constraints in the available CI infrastructure. There hopefully shouldn't + be any meaningful impact resulting from this change, but as there are no longer tests being + run to validate the binaries prior to publishing them there are no longer + guarantees that the wheels for s390x are fully functional (although the + likelihood they are is still high as it works on other platforms). 
If any + issues are encountered with s390x Linux please open an issue. From 5f73e767fb74ec4ff8e9cf6a2dc6de2c95f30144 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 20 Jun 2023 09:45:50 -0400 Subject: [PATCH 34/37] Pin Python version in CI (#905) * Pin Python version in CI Recently github updated the cached version of python installed for 3.7 in the CI environment to 3.7.17. This new binary was not built with bz2 support which is causing failures in our CI jobs that run with 3.7. While we'll be dropping 3.7 support from the main branch (for 0.14.0) in the near future we still support 3.7 on the stable 0.13.x series. This commit pins the python version to the previous patch release which was known to work. * Use 3.7 on windows * Fix windows capitalization * Fix syntax error --- .github/workflows/main.yml | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 2520c934f..1f8642482 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -57,7 +57,7 @@ jobs: strategy: matrix: rust: [stable] - python-version: [3.7, 3.8, 3.9, "3.10", "3.11"] + python-version: ['3.7.16', 3.8, 3.9, "3.10", "3.11"] platform: [ { os: "macOS-latest", python-architecture: "x64", rust-target: "x86_64-apple-darwin" }, { os: "ubuntu-latest", python-architecture: "x64", rust-target: "x86_64-unknown-linux-gnu" }, @@ -76,6 +76,20 @@ jobs: with: python-version: ${{ matrix.python-version }} architecture: ${{ matrix.platform.python-architecture }} + if: runner.os != 'Windows' + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: 3.7 + architecture: ${{ matrix.platform.python-architecture }} + if: ${{ runner.os == 'Windows' && matrix.python-version == '3.7.16' }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: ${{ matrix.platform.python-architecture }} + if: ${{ runner.os == 'Windows' && matrix.python-version != '3.7.16' }} + - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: From 8aa5289a6c99a7c92b7c9ca65cfd2bae07dd4bcf Mon Sep 17 00:00:00 2001 From: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com> Date: Tue, 20 Jun 2023 10:37:31 -0400 Subject: [PATCH 35/37] Implement substitute_node_with_subgraph to Pygraph (#894) * Initial: Add PyGraph.substitute_node_with_subgraph - Added function that substitutes a node with a subgraph. - Implemented from the existent method in `PyDiGraph`. 
* Stubs: Add substitute_node_with_subgraph to graph * Docs: Add release notes * Test: Add python tests * Docs: Fix typo in release-notes * Docs: Add docs string for method * Correction: Change `unwrap` to `?` --------- Co-authored-by: Edwin Navarro --- ...ode-subgraph-pygraph-44f40c01b783bb50.yaml | 25 +++ rustworkx/graph.pyi | 9 + src/graph.rs | 157 ++++++++++++++++++ .../test_substitute_node_with_subgraph.py | 141 ++++++++++++++++ 4 files changed, 332 insertions(+) create mode 100644 releasenotes/notes/add-substitute-node-subgraph-pygraph-44f40c01b783bb50.yaml create mode 100644 tests/rustworkx_tests/graph/test_substitute_node_with_subgraph.py diff --git a/releasenotes/notes/add-substitute-node-subgraph-pygraph-44f40c01b783bb50.yaml b/releasenotes/notes/add-substitute-node-subgraph-pygraph-44f40c01b783bb50.yaml new file mode 100644 index 000000000..49e023678 --- /dev/null +++ b/releasenotes/notes/add-substitute-node-subgraph-pygraph-44f40c01b783bb50.yaml @@ -0,0 +1,25 @@ +--- +features: + - | + Added method substitute_node_with_subgraph to the PyGraph class. + + .. jupyter-execute:: + + import rustworkx + from rustworkx.visualization import * # Needs matplotlib/ + + graph = rustworkx.generators.complete_graph(5) + sub_graph = rustworkx.generators.path_graph(3) + + # Replace node 4 in this graph with sub_graph + # Make sure to connect the graphs at node 2 of the sub_graph + # This is done by passing a function that returns 2 + + graph.substitute_node_with_subgraph(4, sub_graph, lambda _, __, ___: 2) + + # Draw the updated graph + mpl_draw(graph, with_labels=True) +fixes: + - | + Fixes missing method that is present in PyDiGraph but not in PyGraph. + see `#837 `__ for more info. \ No newline at end of file diff --git a/rustworkx/graph.pyi b/rustworkx/graph.pyi index 7277c1eda..44eb37260 100644 --- a/rustworkx/graph.pyi +++ b/rustworkx/graph.pyi @@ -97,6 +97,15 @@ class PyGraph(Generic[S, T]): def remove_node(self, node: int, /) -> None: ... def remove_nodes_from(self, index_list: Sequence[int], /) -> None: ... def subgraph(self, nodes: Sequence[int], /, preserve_attrs: bool = ...) -> PyGraph[S, T]: ... + def substitute_node_with_subgraph( + self, + node: int, + other: PyGraph[S, T], + edge_map_fn: Callable[[int, int, T], Optional[int]], + /, + node_filter: Optional[Callable[[S], bool]] = ..., + edge_weight_map: Optional[Callable[[T], T]] = ..., + ) -> NodeMap: ... def to_dot( self, /, diff --git a/src/graph.rs b/src/graph.rs index 3907514b1..1d97404c5 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -35,6 +35,8 @@ use num_traits::Zero; use numpy::Complex64; use numpy::PyReadonlyArray2; +use crate::iterators::NodeMap; + use super::dot_utils::build_dot; use super::iterators::{EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList}; use super::{ @@ -1594,6 +1596,161 @@ impl PyGraph { Ok(out_dict.into()) } + /// Substitute a node with a PyGraph object + /// + /// :param int node: The node to replace with the PyGraph object + /// :param PyGraph other: The other graph to replace ``node`` with + /// :param callable edge_map_fn: A callable object that will take 3 position + /// parameters, ``(source, target, weight)`` to represent an edge either to + /// or from ``node`` in this graph. The expected return value from this + /// callable is the node index of the node in ``other`` that an edge should + /// be to/from. If None is returned, that edge will be skipped and not + /// be copied. 
+    /// :param callable node_filter: An optional callable object that when used
+    /// will receive a node's payload object from ``other`` and return
+    /// ``True`` if that node is to be included in the graph or not.
+    /// :param callable edge_weight_map: An optional callable object that when
+    /// used will receive an edge's weight/data payload from ``other`` and
+    /// will return an object to use as the weight for a newly created edge
+    /// after the edge is mapped from ``other``. If not specified the weight
+    /// from the edge in ``other`` will be copied by reference and used.
+    ///
+    /// :returns: A mapping of node indices in ``other`` to the equivalent node
+    /// in this graph.
+    /// :rtype: NodeMap
+    ///
+    /// .. note::
+    ///
+    /// The return type is a :class:`rustworkx.NodeMap` which is an unordered
+    /// type. So it does not provide a deterministic ordering between objects
+    /// when iterated over (although the same object will have a consistent
+    /// order when iterated over multiple times).
+    ///
+    #[pyo3(
+        text_signature = "(self, node, other, edge_map_fn, /, node_filter=None, edge_weight_map=None)"
+    )]
+    fn substitute_node_with_subgraph(
+        &mut self,
+        py: Python,
+        node: usize,
+        other: &PyGraph,
+        edge_map_fn: PyObject,
+        node_filter: Option<PyObject>,
+        edge_weight_map: Option<PyObject>,
+    ) -> PyResult<NodeMap> {
+        let filter_fn = |obj: &PyObject, filter_fn: &Option<PyObject>| -> PyResult<bool> {
+            match filter_fn {
+                Some(filter) => {
+                    let res = filter.call1(py, (obj,))?;
+                    res.extract(py)
+                }
+                None => Ok(true),
+            }
+        };
+
+        let weight_map_fn = |obj: &PyObject, weight_fn: &Option<PyObject>| -> PyResult<PyObject> {
+            match weight_fn {
+                Some(weight_fn) => weight_fn.call1(py, (obj,)),
+                None => Ok(obj.clone_ref(py)),
+            }
+        };
+
+        let map_fn = |source: usize, target: usize, weight: &PyObject| -> PyResult<Option<usize>> {
+            let res = edge_map_fn.call1(py, (source, target, weight))?;
+            res.extract(py)
+        };
+
+        let node_index = NodeIndex::new(node);
+        if self.graph.node_weight(node_index).is_none() {
+            return Err(PyIndexError::new_err(format!(
+                "Specified node {} is not in this graph",
+                node
+            )));
+        }
+
+        // Copy all nodes from other to self
+        let mut out_map: DictMap<usize, usize> = DictMap::with_capacity(other.node_count());
+        for node in other.graph.node_indices() {
+            let node_weight: Py<PyAny> = other.graph[node].clone_ref(py);
+            if !filter_fn(&node_weight, &node_filter)? {
+                continue;
+            }
+            let new_index: NodeIndex = self.graph.add_node(node_weight);
+            out_map.insert(node.index(), new_index.index());
+        }
+
+        if out_map.is_empty() {
+            self.graph.remove_node(node_index);
+            return Ok(NodeMap {
+                node_map: DictMap::new(),
+            });
+        }
+
+        // Copy all edges
+        for edge in other.graph.edge_references().filter(|edge| {
+            out_map.contains_key(&edge.target().index())
+                && out_map.contains_key(&edge.source().index())
+        }) {
+            self._add_edge(
+                NodeIndex::new(out_map[&edge.source().index()]),
+                NodeIndex::new(out_map[&edge.target().index()]),
+                weight_map_fn(edge.weight(), &edge_weight_map)?,
+            );
+        }
+        // Incoming and outgoing edges.
+        let in_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self
+            .graph
+            .edge_references()
+            .filter(|edge| edge.target() == node_index)
+            .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py)))
+            .collect();
+        // Keep track of what's present on incoming edges
+        let in_set: HashSet<(NodeIndex, NodeIndex)> =
+            in_edges.iter().map(|edge| (edge.0, edge.1)).collect();
+        // Retrieve outgoing edges. Make sure to not include any incoming edge.
+        let out_edges: Vec<(NodeIndex, NodeIndex, PyObject)> = self
+            .graph
+            .edges(node_index)
+            .filter(|edge| !in_set.contains(&(edge.target(), edge.source())))
+            .map(|edge| (edge.source(), edge.target(), edge.weight().clone_ref(py)))
+            .collect();
+        for (source, target, weight) in in_edges {
+            let old_index: Option<usize> = map_fn(source.index(), target.index(), &weight)?;
+            let target_out: NodeIndex = match old_index {
+                Some(old_index) => match out_map.get(&old_index) {
+                    Some(new_index) => NodeIndex::new(*new_index),
+                    None => {
+                        return Err(PyIndexError::new_err(format!(
+                            "No mapped index {} found",
+                            old_index
+                        )))
+                    }
+                },
+                None => continue,
+            };
+            self._add_edge(source, target_out, weight);
+        }
+        for (source, target, weight) in out_edges {
+            let old_index: Option<usize> = map_fn(source.index(), target.index(), &weight)?;
+            let source_out: NodeIndex = match old_index {
+                Some(old_index) => match out_map.get(&old_index) {
+                    Some(new_index) => NodeIndex::new(*new_index),
+                    None => {
+                        return Err(PyIndexError::new_err(format!(
+                            "No mapped index {} found",
+                            old_index
+                        )))
+                    }
+                },
+                None => continue,
+            };
+            self._add_edge(source_out, target, weight);
+        }
+        // Remove original node
+        self.graph.remove_node(node_index);
+        Ok(NodeMap { node_map: out_map })
+    }
+
     /// Substitute a set of nodes with a single new node.
     ///
     /// .. note::
diff --git a/tests/rustworkx_tests/graph/test_substitute_node_with_subgraph.py b/tests/rustworkx_tests/graph/test_substitute_node_with_subgraph.py
new file mode 100644
index 000000000..2ddd8bf3e
--- /dev/null
+++ b/tests/rustworkx_tests/graph/test_substitute_node_with_subgraph.py
@@ -0,0 +1,141 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+ +import unittest +import rustworkx + + +class TestSubstituteNodeSubGraph(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + self.graph = rustworkx.generators.path_graph(5) + + def test_empty_replacement(self): + in_graph = rustworkx.PyGraph() + res = self.graph.substitute_node_with_subgraph(3, in_graph, lambda _, __, ___: None) + self.assertEqual(res, {}) + self.assertEqual([(0, 1), (1, 2)], self.graph.edge_list()) + + def test_single_node(self): + in_graph = rustworkx.generators.path_graph(1) + res = self.graph.substitute_node_with_subgraph(2, in_graph, lambda _, __, ___: 0) + self.assertEqual(res, {0: 5}) + self.assertEqual([(0, 1), (1, 5), (3, 4), (5, 3)], sorted(self.graph.edge_list())) + + def test_node_filter(self): + in_graph = rustworkx.generators.complete_graph(5) + res = self.graph.substitute_node_with_subgraph( + 0, in_graph, lambda _, __, ___: 2, node_filter=lambda node: node == None + ) + self.assertEqual(res, {i: i + 5 for i in range(5)}) + self.assertEqual( + [ + (1, 2), + (2, 3), + (3, 4), + (5, 6), + (5, 7), + (5, 8), + (5, 9), + (6, 7), + (6, 8), + (6, 9), + (7, 1), + (7, 8), + (7, 9), + (8, 9), + ], + sorted(self.graph.edge_list()), + ) + + def test_edge_weight_modifier(self): + in_graph = rustworkx.PyGraph() + in_graph.add_node("meep") + in_graph.add_node("moop") + in_graph.add_edges_from( + [ + ( + 0, + 1, + "edge", + ) + ] + ) + res = self.graph.substitute_node_with_subgraph( + 2, + in_graph, + lambda _, __, ___: 0, + edge_weight_map=lambda edge: edge + "-migrated", + ) + self.assertEqual([(0, 1), (3, 4), (5, 6), (1, 5), (5, 3)], self.graph.edge_list()) + self.assertEqual("edge-migrated", self.graph.get_edge_data(5, 6)) + self.assertEqual(res, {0: 5, 1: 6}) + + def test_none_mapping(self): + in_graph = rustworkx.PyGraph() + in_graph.add_node("boop") + in_graph.add_node("beep") + in_graph.add_edges_from([(0, 1, "edge")]) + res = self.graph.substitute_node_with_subgraph(2, in_graph, lambda _, __, ___: None) + self.assertEqual([(0, 1), (3, 4), (5, 6)], self.graph.edge_list()) + self.assertEqual(res, {0: 5, 1: 6}) + + def test_multiple_mapping(self): + graph = rustworkx.generators.star_graph(5) + in_graph = rustworkx.generators.star_graph(3) + + def map_function(_source, target, _weight): + if target > 2: + return 2 + return 1 + + res = graph.substitute_node_with_subgraph(0, in_graph, map_function) + self.assertEqual({0: 5, 1: 6, 2: 7}, res) + expected = [(5, 6), (5, 7), (7, 4), (7, 3), (6, 2), (6, 1)] + self.assertEqual(sorted(expected), sorted(graph.edge_list())) + + def test_multiple_mapping_full(self): + graph = rustworkx.generators.star_graph(5) + in_graph = rustworkx.generators.star_graph(weights=list(range(3))) + in_graph.add_edge(1, 2, None) + + def map_function(source, target, _weight): + if target > 2: + return 2 + return 1 + + def filter_fn(node): + return node > 0 + + def map_weight(_): + return "migrated" + + res = graph.substitute_node_with_subgraph(0, in_graph, map_function, filter_fn, map_weight) + self.assertEqual({1: 5, 2: 6}, res) + expected = [ + (5, 6, "migrated"), + (6, 4, None), + (6, 3, None), + (5, 2, None), + (5, 1, None), + ] + self.assertEqual(expected, graph.weighted_edge_list()) + + def test_invalid_target(self): + in_graph = rustworkx.generators.grid_graph(5, 5) + with self.assertRaises(IndexError): + self.graph.substitute_node_with_subgraph(0, in_graph, lambda *args: 42) + + def test_invalid_node_id(self): + in_graph = rustworkx.generators.grid_graph(5, 5) + with self.assertRaises(IndexError): + 
            self.graph.substitute_node_with_subgraph(16, in_graph, lambda *args: None)

From a468efd7508ad490ed184ee541317dd224ea4110 Mon Sep 17 00:00:00 2001
From: danielleodigie <97267313+danielleodigie@users.noreply.github.com>
Date: Tue, 20 Jun 2023 12:21:21 -0400
Subject: [PATCH 36/37] Adding Option to change parallel edge behavior in adjacency_matrix functions (#899)

* working on adding different parallel edge behavior
* Working on graph_adjacency_matrix
* Implementing changes to graph_adjacency_matrix and digraph_adjacency_matrix
* working on release notes
* Fixed release notes and docs
* Ran cargo fmt
* Ran cargo clippy
* Fixed digraph_adjacency_matrix, passes tests
* Removed mpl_draw from release notes
* Changed if-else blocks in adjacency_matrix functions to match blocks. Wrote tests.
* Fixed tests to pass lint

---------

Co-authored-by: Edwin Navarro
---
 ...and-adjacency-matrix-11e56c1f49b8e4e5.yaml | 32 +++++++
 src/connectivity/mod.rs | 86 +++++++++++++++++--
 .../digraph/test_adjacency_matrix.py | 51 +++++++++++
 .../graph/test_adjencency_matrix.py | 51 +++++++++++
 4 files changed, 211 insertions(+), 9 deletions(-)
 create mode 100644 releasenotes/notes/expand-adjacency-matrix-11e56c1f49b8e4e5.yaml

diff --git a/releasenotes/notes/expand-adjacency-matrix-11e56c1f49b8e4e5.yaml b/releasenotes/notes/expand-adjacency-matrix-11e56c1f49b8e4e5.yaml
new file mode 100644
index 000000000..f15411209
--- /dev/null
+++ b/releasenotes/notes/expand-adjacency-matrix-11e56c1f49b8e4e5.yaml
@@ -0,0 +1,32 @@
+---
+features:
+  - |
+    The functions :func:`~rustworkx.graph_adjacency_matrix` and :func:`~rustworkx.digraph_adjacency_matrix` now have the option to adjust parallel edge behavior.
+    Instead of just the default sum behavior, the value in the output matrix can be the minimum ("min"), maximum ("max"), or average ("avg") of the weights of the parallel edges.
+    For example:
+
+    .. jupyter-execute::
+
+        import rustworkx as rx
+        graph = rx.PyGraph()
+        a = graph.add_node("A")
+        b = graph.add_node("B")
+        c = graph.add_node("C")
+
+        graph.add_edges_from([
+            (a, b, 3.0),
+            (a, b, 1.0),
+            (a, c, 2.0),
+            (b, c, 7.0),
+            (c, a, 1.0),
+            (b, c, 2.0),
+            (a, b, 4.0)
+        ])
+
+        print("Adjacency Matrix with Summed Parallel Edges")
+        print(rx.graph_adjacency_matrix(graph, weight_fn= lambda x: float(x)))
+        print("Adjacency Matrix with Averaged Parallel Edges")
+        print(rx.graph_adjacency_matrix(graph, weight_fn= lambda x: float(x), parallel_edge="avg"))
+
+
+
\ No newline at end of file
diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs
index 70cfbb6d2..7d38b3cf9 100644
--- a/src/connectivity/mod.rs
+++ b/src/connectivity/mod.rs
@@ -266,7 +266,7 @@ pub fn is_weakly_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
 /// Return the adjacency matrix for a PyDiGraph object
 ///
 /// In the case where there are multiple edges between nodes the value in the
-/// output matrix will be the sum of the edges' weights.
+/// output matrix will be assigned based on a given parameter. Currently, the minimum, maximum, average, and default sum are supported.
 ///
 /// :param PyDiGraph graph: The DiGraph used to generate the adjacency matrix
 /// from
@@ -290,13 +290,16 @@ pub fn is_weakly_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
 /// value. This is the default value in the output matrix and it is used
 /// to indicate the absence of an edge between 2 nodes. By default this is
 /// ``0.0``.
+/// :param String parallel_edge: Optional argument that determines how the function handles parallel edges.
+/// ``"min"`` causes the value in the output matrix to be the minimum of the edges' weights, and similar behavior can be expected for ``"max"`` and ``"avg"``.
+/// The function defaults to ``"sum"`` behavior, where the value in the output matrix is the sum of all parallel edge weights.
 ///
 /// :return: The adjacency matrix for the input directed graph as a numpy array
 /// :rtype: numpy.ndarray
 #[pyfunction]
 #[pyo3(
-    signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0),
-    text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0)"
+    signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge="sum"),
+    text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge=\"sum\")"
 )]
 pub fn digraph_adjacency_matrix(
     py: Python,
@@ -304,15 +307,43 @@ pub fn digraph_adjacency_matrix(
     weight_fn: Option<PyObject>,
     default_weight: f64,
     null_value: f64,
+    parallel_edge: &str,
 ) -> PyResult<PyObject> {
     let n = graph.node_count();
     let mut matrix = Array2::<f64>::from_elem((n, n), null_value);
+    let mut parallel_edge_count = HashMap::new();
     for (i, j, weight) in get_edge_iter_with_weights(&graph.graph) {
         let edge_weight = weight_callable(py, &weight_fn, &weight, default_weight)?;
         if matrix[[i, j]] == null_value || (null_value.is_nan() && matrix[[i, j]].is_nan()) {
             matrix[[i, j]] = edge_weight;
         } else {
-            matrix[[i, j]] += edge_weight;
+            match parallel_edge {
+                "sum" => {
+                    matrix[[i, j]] += edge_weight;
+                }
+                "min" => {
+                    let weight_min = matrix[[i, j]].min(edge_weight);
+                    matrix[[i, j]] = weight_min;
+                }
+                "max" => {
+                    let weight_max = matrix[[i, j]].max(edge_weight);
+                    matrix[[i, j]] = weight_max;
+                }
+                "avg" => {
+                    if parallel_edge_count.contains_key(&[i, j]) {
+                        matrix[[i, j]] = (matrix[[i, j]] * parallel_edge_count[&[i, j]] as f64
+                            + edge_weight)
+                            / ((parallel_edge_count[&[i, j]] + 1) as f64);
+                        *parallel_edge_count.get_mut(&[i, j]).unwrap() += 1;
+                    } else {
+                        parallel_edge_count.insert([i, j], 2);
+                        matrix[[i, j]] = (matrix[[i, j]] + edge_weight) / 2.0;
+                    }
+                }
+                _ => {
+                    return Err(PyValueError::new_err("Parallel edges can currently only be dealt with using \"sum\", \"min\", \"max\", or \"avg\"."));
+                }
+            }
         }
     }
     Ok(matrix.into_pyarray(py).into())
@@ -321,7 +352,7 @@ pub fn digraph_adjacency_matrix(
 /// Return the adjacency matrix for a PyGraph class
 ///
 /// In the case where there are multiple edges between nodes the value in the
-/// output matrix will be the sum of the edges' weights.
+/// output matrix will be assigned based on a given parameter. Currently, the minimum, maximum, average, and default sum are supported.
 ///
 /// :param PyGraph graph: The graph used to generate the adjacency matrix from
 /// :param weight_fn: A callable object (function, lambda, etc) which
@@ -344,13 +375,16 @@ pub fn digraph_adjacency_matrix(
 /// value. This is the default value in the output matrix and it is used
 /// to indicate the absence of an edge between 2 nodes. By default this is
 /// ``0.0``.
+/// :param String parallel_edge: Optional argument that determines how the function handles parallel edges.
+/// ``"min"`` causes the value in the output matrix to be the minimum of the edges' weights, and similar behavior can be expected for ``"max"`` and ``"avg"``.
+/// The function defaults to ``"sum"`` behavior, where the value in the output matrix is the sum of all parallel edge weights.
 ///
 /// :return: The adjacency matrix for the input graph as a numpy array
 /// :rtype: numpy.ndarray
 #[pyfunction]
 #[pyo3(
-    signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0),
-    text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0)"
+    signature=(graph, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge="sum"),
+    text_signature = "(graph, /, weight_fn=None, default_weight=1.0, null_value=0.0, parallel_edge=\"sum\")"
 )]
 pub fn graph_adjacency_matrix(
     py: Python,
@@ -358,17 +392,51 @@ pub fn graph_adjacency_matrix(
     weight_fn: Option<PyObject>,
     default_weight: f64,
     null_value: f64,
+    parallel_edge: &str,
 ) -> PyResult<PyObject> {
     let n = graph.node_count();
     let mut matrix = Array2::<f64>::from_elem((n, n), null_value);
+    let mut parallel_edge_count = HashMap::new();
     for (i, j, weight) in get_edge_iter_with_weights(&graph.graph) {
         let edge_weight = weight_callable(py, &weight_fn, &weight, default_weight)?;
         if matrix[[i, j]] == null_value || (null_value.is_nan() && matrix[[i, j]].is_nan()) {
             matrix[[i, j]] = edge_weight;
             matrix[[j, i]] = edge_weight;
         } else {
-            matrix[[i, j]] += edge_weight;
-            matrix[[j, i]] += edge_weight;
+            match parallel_edge {
+                "sum" => {
+                    matrix[[i, j]] += edge_weight;
+                    matrix[[j, i]] += edge_weight;
+                }
+                "min" => {
+                    let weight_min = matrix[[i, j]].min(edge_weight);
+                    matrix[[i, j]] = weight_min;
+                    matrix[[j, i]] = weight_min;
+                }
+                "max" => {
+                    let weight_max = matrix[[i, j]].max(edge_weight);
+                    matrix[[i, j]] = weight_max;
+                    matrix[[j, i]] = weight_max;
+                }
+                "avg" => {
+                    if parallel_edge_count.contains_key(&[i, j]) {
+                        matrix[[i, j]] = (matrix[[i, j]] * parallel_edge_count[&[i, j]] as f64
+                            + edge_weight)
+                            / ((parallel_edge_count[&[i, j]] + 1) as f64);
+                        matrix[[j, i]] = (matrix[[j, i]] * parallel_edge_count[&[i, j]] as f64
+                            + edge_weight)
+                            / ((parallel_edge_count[&[i, j]] + 1) as f64);
+                        *parallel_edge_count.get_mut(&[i, j]).unwrap() += 1;
+                    } else {
+                        parallel_edge_count.insert([i, j], 2);
+                        matrix[[i, j]] = (matrix[[i, j]] + edge_weight) / 2.0;
+                        matrix[[j, i]] = (matrix[[j, i]] + edge_weight) / 2.0;
+                    }
+                }
+                _ => {
+                    return Err(PyValueError::new_err("Parallel edges can currently only be dealt with using \"sum\", \"min\", \"max\", or \"avg\"."));
+                }
+            }
         }
     }
     Ok(matrix.into_pyarray(py).into())
diff --git a/tests/rustworkx_tests/digraph/test_adjacency_matrix.py b/tests/rustworkx_tests/digraph/test_adjacency_matrix.py
index f2f58488a..38d998559 100644
--- a/tests/rustworkx_tests/digraph/test_adjacency_matrix.py
+++ b/tests/rustworkx_tests/digraph/test_adjacency_matrix.py
@@ -262,3 +262,54 @@ def test_nan_null(self):
             edge_list,
             [(0, 1, 1 + 0j), (1, 0, 1 + 0j), (1, 2, 1 + 0j), (2, 1, 1 + 0j)],
         )
+
+    def test_parallel_edge(self):
+        graph = rustworkx.PyDiGraph()
+        a = graph.add_node("A")
+        b = graph.add_node("B")
+        c = graph.add_node("C")
+
+        graph.add_edges_from(
+            [
+                (a, b, 3.0),
+                (a, b, 1.0),
+                (a, c, 2.0),
+                (b, c, 7.0),
+                (c, a, 1.0),
+                (b, c, 2.0),
+                (a, b, 4.0),
+            ]
+        )
+
+        min_matrix = rustworkx.digraph_adjacency_matrix(
+            graph, weight_fn=lambda x: float(x), parallel_edge="min"
+        )
+        np.testing.assert_array_equal(
+            [[0.0, 1.0, 2.0], [0.0, 0.0, 2.0], [1.0, 0.0, 0.0]], min_matrix
+        )
+
+        max_matrix = rustworkx.digraph_adjacency_matrix(
+            graph, weight_fn=lambda x: float(x), parallel_edge="max"
+        )
+        np.testing.assert_array_equal(
+            [[0.0, 4.0, 2.0], [0.0, 0.0, 7.0], [1.0, 0.0, 0.0]], max_matrix
+        )
+
+        avg_matrix = rustworkx.digraph_adjacency_matrix(
+            graph, weight_fn=lambda x: float(x), parallel_edge="avg"
+        )
+
np.testing.assert_array_equal( + [[0.0, 8 / 3.0, 2.0], [0.0, 0.0, 4.5], [1.0, 0.0, 0.0]], avg_matrix + ) + + sum_matrix = rustworkx.digraph_adjacency_matrix( + graph, weight_fn=lambda x: float(x), parallel_edge="sum" + ) + np.testing.assert_array_equal( + [[0.0, 8.0, 2.0], [0.0, 0.0, 9.0], [1.0, 0.0, 0.0]], sum_matrix + ) + + with self.assertRaises(ValueError): + rustworkx.digraph_adjacency_matrix( + graph, weight_fn=lambda x: float(x), parallel_edge="error" + ) diff --git a/tests/rustworkx_tests/graph/test_adjencency_matrix.py b/tests/rustworkx_tests/graph/test_adjencency_matrix.py index 30934dfff..d303c7550 100644 --- a/tests/rustworkx_tests/graph/test_adjencency_matrix.py +++ b/tests/rustworkx_tests/graph/test_adjencency_matrix.py @@ -260,3 +260,54 @@ def test_nan_null(self): edge_list, [(0, 1, 1 + 0j), (1, 2, 1 + 0j)], ) + + def test_parallel_edge(self): + graph = rustworkx.PyGraph() + a = graph.add_node("A") + b = graph.add_node("B") + c = graph.add_node("C") + + graph.add_edges_from( + [ + (a, b, 3.0), + (a, b, 1.0), + (a, c, 2.0), + (b, c, 7.0), + (c, a, 1.0), + (b, c, 2.0), + (a, b, 4.0), + ] + ) + + min_matrix = rustworkx.graph_adjacency_matrix( + graph, weight_fn=lambda x: float(x), parallel_edge="min" + ) + np.testing.assert_array_equal( + [[0.0, 1.0, 1.0], [1.0, 0.0, 2.0], [1.0, 2.0, 0.0]], min_matrix + ) + + max_matrix = rustworkx.graph_adjacency_matrix( + graph, weight_fn=lambda x: float(x), parallel_edge="max" + ) + np.testing.assert_array_equal( + [[0.0, 4.0, 2.0], [4.0, 0.0, 7.0], [2.0, 7.0, 0.0]], max_matrix + ) + + avg_matrix = rustworkx.graph_adjacency_matrix( + graph, weight_fn=lambda x: float(x), parallel_edge="avg" + ) + np.testing.assert_array_equal( + [[0.0, 8 / 3.0, 1.5], [8 / 3.0, 0.0, 4.5], [1.5, 4.5, 0.0]], avg_matrix + ) + + sum_matrix = rustworkx.graph_adjacency_matrix( + graph, weight_fn=lambda x: float(x), parallel_edge="sum" + ) + np.testing.assert_array_equal( + [[0.0, 8.0, 3.0], [8.0, 0.0, 9.0], [3.0, 9.0, 0.0]], sum_matrix + ) + + with self.assertRaises(ValueError): + rustworkx.graph_adjacency_matrix( + graph, weight_fn=lambda x: float(x), parallel_edge="error" + ) From c4f0c2e239dc45c1a5c51fc4fe0cc42be5ea9b78 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com> Date: Wed, 21 Jun 2023 07:32:17 -0400 Subject: [PATCH 37/37] Drop support for python 3.7 for 0.14.0 (#907) * Initial: Bump python to 3.8 and rustc to 1.63.0 * CI: Testing new workflows * CI: Second Attempt * CI: Third attempt * CI: 4th Attempt * CI: Fifth attempt, skip cp37. 
* CI: Sixth attempt, restored rustc to 1.56.1 * Docs: Add release note * Chore: Remove comments on main workflow * Update releasenotes/notes/drop-python-3.7-c71f8b6559beaf86.yaml * Remove Python3.7 from mergify config --------- Co-authored-by: Matthew Treinish --- .github/workflows/main.yml | 2 +- .github/workflows/wheels.yml | 26 +++++++++---------- .mergify.yml | 3 --- .../drop-python-3.7-c71f8b6559beaf86.yaml | 5 ++++ setup.py | 3 +-- 5 files changed, 20 insertions(+), 19 deletions(-) create mode 100644 releasenotes/notes/drop-python-3.7-c71f8b6559beaf86.yaml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1f8642482..007308e5f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -57,7 +57,7 @@ jobs: strategy: matrix: rust: [stable] - python-version: ['3.7.16', 3.8, 3.9, "3.10", "3.11"] + python-version: [3.8, 3.9, "3.10", "3.11"] platform: [ { os: "macOS-latest", python-architecture: "x64", rust-target: "x86_64-apple-darwin" }, { os: "ubuntu-latest", python-architecture: "x64", rust-target: "x86_64-unknown-linux-gnu" }, diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 6780979d4..4312f16de 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -26,7 +26,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' - name: Install deps run: pip install -U twine setuptools-rust - name: Build sdist @@ -51,7 +51,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' - uses: dtolnay/rust-toolchain@stable - name: Install cibuildwheel run: | @@ -65,7 +65,7 @@ jobs: CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.9 CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64:latest CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest - CIBW_SKIP: cp36-* pp* *win32 *musl* + CIBW_SKIP: cp36-* cp37-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests @@ -89,7 +89,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU uses: docker/setup-qemu-action@v2 @@ -107,7 +107,7 @@ jobs: CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.9 CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64:latest CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest - CIBW_SKIP: cp36-* pp* *win32 *musl* + CIBW_SKIP: cp36-* cp37-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust CIBW_TEST_REQUIRES: networkx scipy testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests @@ -132,7 +132,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU uses: docker/setup-qemu-action@v2 @@ -150,7 +150,7 @@ jobs: CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.9 CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64:latest CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest - CIBW_SKIP: cp36-* cp39-* cp310-* cp311-* pp* *win32 *musl* + CIBW_SKIP: cp36-* cp37-* cp39-* cp310-* cp311-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: 
python -m unittest discover {project}/tests/rustworkx_tests @@ -175,7 +175,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU uses: docker/setup-qemu-action@v2 @@ -218,7 +218,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU uses: docker/setup-qemu-action@v2 @@ -236,7 +236,7 @@ jobs: CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.9 CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64:latest CIBW_MANYLINUX_I686_IMAGE: quay.io/pypa/manylinux2014_i686:latest - CIBW_SKIP: cp36-* cp39-* cp310-* cp311-* pp* *win32 *musl* + CIBW_SKIP: cp36-* cp37-* cp39-* cp310-* cp311-* pp* *win32 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests @@ -262,7 +262,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' - uses: dtolnay/rust-toolchain@stable - name: Set up QEMU uses: docker/setup-qemu-action@v2 @@ -329,7 +329,7 @@ jobs: - uses: actions/setup-python@v4 name: Install Python with: - python-version: '3.7' + python-version: '3.8' architecture: 'x86' - uses: dtolnay/rust-toolchain@stable with: @@ -344,7 +344,7 @@ jobs: python -m cibuildwheel --output-dir wheelhouse env: CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=10.9 - CIBW_SKIP: cp36-* pp* *amd64 *musl* + CIBW_SKIP: cp36-* cp37-* pp* *amd64 *musl* CIBW_BEFORE_BUILD: pip install -U setuptools-rust CIBW_TEST_REQUIRES: networkx testtools fixtures CIBW_TEST_COMMAND: python -m unittest discover {project}/tests/rustworkx_tests diff --git a/.mergify.yml b/.mergify.yml index f599bb38e..3deb83e61 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -1,9 +1,6 @@ queue_rules: - name: automerge conditions: - - check-success=python3.7-x64 windows-latest - - check-success=python3.7-x64 ubuntu-latest - - check-success=python3.7-x64 macOS-latest - check-success=python3.8-x64 windows-latest - check-success=python3.8-x64 ubuntu-latest - check-success=python3.8-x64 macOS-latest diff --git a/releasenotes/notes/drop-python-3.7-c71f8b6559beaf86.yaml b/releasenotes/notes/drop-python-3.7-c71f8b6559beaf86.yaml new file mode 100644 index 000000000..acb680b71 --- /dev/null +++ b/releasenotes/notes/drop-python-3.7-c71f8b6559beaf86.yaml @@ -0,0 +1,5 @@ +--- +upgrade: + - | + The minimum required Python version was raised to Python 3.8. + To use rustworkx, please ensure you are using Python >= 3.8. \ No newline at end of file diff --git a/setup.py b/setup.py index 6e8121ce2..815f17538 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,6 @@ def readme(): "Intended Audience :: Science/Research", "Programming Language :: Rust", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -84,7 +83,7 @@ def readme(): include_package_data=True, packages=PKG_PACKAGES, zip_safe=False, - python_requires=">=3.7", + python_requires=">=3.8", install_requires=PKG_INSTALL_REQUIRES, extras_require={ 'mpl': mpl_extras,