diff --git a/Cargo.lock b/Cargo.lock index fd443490c0..076dd27234 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,7 +741,7 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustworkx" -version = "0.17.0" +version = "0.17.1" dependencies = [ "fixedbitset", "flate2", @@ -770,7 +770,7 @@ dependencies = [ [[package]] name = "rustworkx-core" -version = "0.17.0" +version = "0.17.1" dependencies = [ "fixedbitset", "foldhash", diff --git a/Cargo.toml b/Cargo.toml index fb7d64731c..7ac526c62f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ members = [ ] [workspace.package] -version = "0.17.0" +version = "0.17.1" edition = "2021" rust-version = "1.79" authors = ["Matthew Treinish "] @@ -62,7 +62,7 @@ rayon.workspace = true serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" smallvec = { version = "1.0", features = ["union"] } -rustworkx-core = { path = "rustworkx-core", version = "=0.17.0" } +rustworkx-core = { path = "rustworkx-core", version = "=0.17.1" } flate2 = "1.0.35" [dependencies.pyo3] diff --git a/docs/source/conf.py b/docs/source/conf.py index 8d762cf502..24bdf3b731 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version. version = '0.17' # The full version, including alpha/beta/rc tags. -release = '0.17.0' +release = '0.17.1' extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', diff --git a/pyproject.toml b/pyproject.toml index c54323c3d0..f86568d2f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rustworkx" -version = "0.17.0" +version = "0.17.1" description = "A High-Performance Graph Library for Python" requires-python = ">=3.9" dependencies = [ @@ -22,9 +22,8 @@ classifiers=[ "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", - "Development Status :: 5 - Production/Stable", ] -keywords = ["Networks", "network", "graph", "Graph Theory", "DAG"] +keywords = ["Networks", "Network", "Graph", "Graph Theory", "DAG"] [tool.setuptools] packages = ["rustworkx", "rustworkx.visualization"] diff --git a/releasenotes/notes/add-distance-matrix-8cbe417d6f4eaf6d.yaml b/releasenotes/notes/add-distance-matrix-8cbe417d6f4eaf6d.yaml new file mode 100644 index 0000000000..6ae96cc216 --- /dev/null +++ b/releasenotes/notes/add-distance-matrix-8cbe417d6f4eaf6d.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new function ``rustworkx_core::shortest_path::distance_matrix`` + to rustworkx-core. This function is the equivalent of :func:`.distance_matrix` + for the Python library, but as a generic Rust function for rustworkx-core. diff --git a/rustworkx-core/src/shortest_path/distance_matrix.rs b/rustworkx-core/src/shortest_path/distance_matrix.rs new file mode 100644 index 0000000000..82358ed9cf --- /dev/null +++ b/rustworkx-core/src/shortest_path/distance_matrix.rs @@ -0,0 +1,159 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use std::hash::Hash; + +use hashbrown::HashMap; + +use fixedbitset::FixedBitSet; +use ndarray::prelude::*; +use petgraph::visit::{ + GraphProp, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, NodeIndexable, +}; +use petgraph::{Incoming, Outgoing}; +use rayon::prelude::*; + +/// Get the distance matrix for a graph +/// +/// The generated distance matrix assumes the edge weight for all edges is +/// 1.0 and returns a matrix. +/// +/// This function is also multithreaded and will run in parallel if the number +/// of nodes in the graph is above the value of `parallel_threshold`. If the function +/// will be running in parallel the env var +/// `RAYON_NUM_THREADS` can be used to adjust how many threads will be used. +/// +/// # Arguments: +/// +/// * graph - The graph object to compute the distance matrix for. +/// * parallel_threshold - The threshold in number of nodes to run this function in parallel. +/// If `graph` has fewer nodes than this the algorithm will run serially. A good default +/// to use for this is 300. +/// * as_undirected - If the input graph is directed and this is set to true the output +/// matrix generated +/// * null_value - The value to use for the absence of a path in the graph. +/// +/// # Returns +/// +/// A 2d ndarray [`Array`] of the distance matrix +/// +/// # Example +/// +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::shortest_path::distance_matrix; +/// use ndarray::{array, Array2}; +/// +/// let graph = petgraph::graph::UnGraph::<(), ()>::from_edges(&[ +/// (0, 1), (0, 6), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6) +/// ]); +/// let distance_matrix = distance_matrix(&graph, 300, false, 0.); +/// let expected: Array2 = array![ +/// [0.0, 1.0, 2.0, 3.0, 3.0, 2.0, 1.0], +/// [1.0, 0.0, 1.0, 2.0, 3.0, 3.0, 2.0], +/// [2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 3.0], +/// [3.0, 2.0, 1.0, 0.0, 1.0, 2.0, 3.0], +/// [3.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0], +/// [2.0, 3.0, 3.0, 2.0, 1.0, 0.0, 1.0], +/// [1.0, 2.0, 3.0, 3.0, 2.0, 1.0, 0.0], +/// ]; +/// assert_eq!(distance_matrix, expected) +/// ``` +pub fn distance_matrix( + graph: G, + parallel_threshold: usize, + as_undirected: bool, + null_value: f64, +) -> Array2 +where + G: Sync + IntoNeighborsDirected + NodeCount + NodeIndexable + IntoNodeIdentifiers + GraphProp, + G::NodeId: Hash + Eq + Sync, +{ + let n = graph.node_count(); + let node_map: HashMap = if n != graph.node_bound() { + graph + .node_identifiers() + .enumerate() + .map(|(i, v)| (v, i)) + .collect() + } else { + HashMap::new() + }; + let node_map_inv: Vec = if n != graph.node_bound() { + graph.node_identifiers().collect() + } else { + Vec::new() + }; + let mut node_map_fn: Box usize> = if n != graph.node_bound() { + Box::new(|n: G::NodeId| -> usize { node_map[&n] }) + } else { + Box::new(|n: G::NodeId| -> usize { graph.to_index(n) }) + }; + let mut reverse_node_map: Box G::NodeId> = if n != graph.node_bound() { + Box::new(|n: usize| -> G::NodeId { node_map_inv[n] }) + } else { + Box::new(|n: usize| -> G::NodeId { graph.from_index(n) }) + }; + let mut matrix = Array2::::from_elem((n, n), null_value); + let neighbors = if as_undirected { + (0..n) + .map(|index| { + graph + .neighbors_directed(reverse_node_map(index), Incoming) + .chain(graph.neighbors_directed(reverse_node_map(index), Outgoing)) + .map(&mut node_map_fn) + .collect::() + }) + .collect::>() + } else { + (0..n) + .map(|index| { + graph + .neighbors(reverse_node_map(index)) + .map(&mut node_map_fn) + .collect::() + }) + .collect::>() + }; + let bfs_traversal = |start: usize, mut row: ArrayViewMut1| { + let mut distance = 0.0; + let mut seen = FixedBitSet::with_capacity(n); + let mut next = FixedBitSet::with_capacity(n); + let mut cur = FixedBitSet::with_capacity(n); + cur.put(start); + while !cur.is_clear() { + next.clear(); + for found in cur.ones() { + row[[found]] = distance; + next |= &neighbors[found]; + } + seen.union_with(&cur); + next.difference_with(&seen); + distance += 1.0; + ::std::mem::swap(&mut cur, &mut next); + } + }; + if n < parallel_threshold { + matrix + .axis_iter_mut(Axis(0)) + .enumerate() + .for_each(|(index, row)| bfs_traversal(index, row)); + } else { + // Parallelize by row and iterate from each row index in BFS order + matrix + .axis_iter_mut(Axis(0)) + .into_par_iter() + .enumerate() + .for_each(|(index, row)| bfs_traversal(index, row)); + } + matrix +} diff --git a/rustworkx-core/src/shortest_path/mod.rs b/rustworkx-core/src/shortest_path/mod.rs index 214fff8beb..b00e971140 100644 --- a/rustworkx-core/src/shortest_path/mod.rs +++ b/rustworkx-core/src/shortest_path/mod.rs @@ -19,6 +19,7 @@ mod all_shortest_paths; mod astar; mod bellman_ford; mod dijkstra; +mod distance_matrix; mod k_shortest_path; mod single_source_all_shortest_paths; @@ -26,5 +27,6 @@ pub use all_shortest_paths::all_shortest_paths; pub use astar::astar; pub use bellman_ford::{bellman_ford, negative_cycle_finder}; pub use dijkstra::dijkstra; +pub use distance_matrix::distance_matrix; pub use k_shortest_path::k_shortest_path; pub use single_source_all_shortest_paths::single_source_all_shortest_paths; diff --git a/src/shortest_path/distance_matrix.rs b/src/shortest_path/distance_matrix.rs index 2f417b23f4..ead0d36d3a 100644 --- a/src/shortest_path/distance_matrix.rs +++ b/src/shortest_path/distance_matrix.rs @@ -10,33 +10,12 @@ // License for the specific language governing permissions and limitations // under the License. -use std::ops::Index; - -use hashbrown::{HashMap, HashSet}; - use ndarray::prelude::*; -use petgraph::prelude::*; use petgraph::EdgeType; -use rayon::prelude::*; -use crate::NodesRemoved; use crate::StablePyGraph; -#[inline] -fn apply( - map_fn: &Option, - x: I, - default: >::Output, -) -> >::Output -where - M: Index, - >::Output: Sized + Copy, -{ - match map_fn { - Some(map) => map[x], - None => default, - } -} +use rustworkx_core::shortest_path; pub fn compute_distance_matrix( graph: &StablePyGraph, @@ -44,71 +23,5 @@ pub fn compute_distance_matrix( as_undirected: bool, null_value: f64, ) -> Array2 { - let node_map: Option> = if graph.nodes_removed() { - Some( - graph - .node_indices() - .enumerate() - .map(|(i, v)| (v, i)) - .collect(), - ) - } else { - None - }; - - let node_map_inv: Option> = if graph.nodes_removed() { - Some(graph.node_indices().collect()) - } else { - None - }; - - let n = graph.node_count(); - let mut matrix = Array2::::from_elem((n, n), null_value); - let bfs_traversal = |index: usize, mut row: ArrayViewMut1| { - let mut seen: HashMap = HashMap::with_capacity(n); - let start_index = apply(&node_map_inv, index, NodeIndex::new(index)); - let mut level = 0; - let mut next_level: HashSet = HashSet::new(); - next_level.insert(start_index); - while !next_level.is_empty() { - let this_level = next_level; - next_level = HashSet::new(); - let mut found: Vec = Vec::new(); - for v in this_level { - if !seen.contains_key(&v) { - seen.insert(v, level); - found.push(v); - row[[apply(&node_map, &v, v.index())]] = level as f64; - } - } - if seen.len() == n { - return; - } - for node in found { - for v in graph.neighbors_directed(node, petgraph::Direction::Outgoing) { - next_level.insert(v); - } - if graph.is_directed() && as_undirected { - for v in graph.neighbors_directed(node, petgraph::Direction::Incoming) { - next_level.insert(v); - } - } - } - level += 1 - } - }; - if n < parallel_threshold { - matrix - .axis_iter_mut(Axis(0)) - .enumerate() - .for_each(|(index, row)| bfs_traversal(index, row)); - } else { - // Parallelize by row and iterate from each row index in BFS order - matrix - .axis_iter_mut(Axis(0)) - .into_par_iter() - .enumerate() - .for_each(|(index, row)| bfs_traversal(index, row)); - } - matrix + shortest_path::distance_matrix(graph, parallel_threshold, as_undirected, null_value) } diff --git a/uv.lock b/uv.lock index ad53904928..d63102197d 100644 --- a/uv.lock +++ b/uv.lock @@ -2602,7 +2602,7 @@ wheels = [ [[package]] name = "rustworkx" -version = "0.17.0" +version = "0.17.1" source = { editable = "." } dependencies = [ { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },