From 0017d000e6bf82aed0da095410de6f4a4596fa4d Mon Sep 17 00:00:00 2001 From: Fabio Mazza Date: Tue, 19 Nov 2024 23:38:57 +0100 Subject: [PATCH] Read gzipped graphml files (#1315) * Use flate2 crate to read gzipped graphml files * fix typo * run rustfmt * apply suggestion from clippy * add test for gzipped graphml * write separate function * add changelog * reformat * Revert "write separate function" This reverts commit 2dba2529004f6cb7424eb27cbae319c12d6e4bf6. * run with compression argument * update contribution * lint python * add stub * try avoid error in test in Windows * use Option for compression variable * correct text signature --------- Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> --- Cargo.lock | 35 ++++++++++++ Cargo.toml | 1 + ...t-description-string-564c7e376b8e7304.yaml | 5 ++ rustworkx/rustworkx.pyi | 6 +- src/graphml.rs | 47 +++++++++++++--- tests/test_graphml.py | 55 ++++++++++++++++++- 6 files changed, 139 insertions(+), 10 deletions(-) create mode 100644 releasenotes/notes/short-description-string-564c7e376b8e7304.yaml diff --git a/Cargo.lock b/Cargo.lock index 45fd8248a7..457b25b0be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -39,6 +45,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -82,6 +97,16 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -194,6 +219,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -529,6 +563,7 @@ version = "0.16.0" dependencies = [ "ahash", "fixedbitset", + "flate2", "hashbrown 0.14.5", "indexmap", "ndarray", diff --git a/Cargo.toml b/Cargo.toml index 3bffc9f379..a07c76d632 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" smallvec = { version = "1.0", features = ["union"] } rustworkx-core = { path = "rustworkx-core", version = "=0.16.0" } +flate2 = 
"1.0.35" [dependencies.pyo3] version = "0.22.6" diff --git a/releasenotes/notes/short-description-string-564c7e376b8e7304.yaml b/releasenotes/notes/short-description-string-564c7e376b8e7304.yaml new file mode 100644 index 0000000000..5ddc62d189 --- /dev/null +++ b/releasenotes/notes/short-description-string-564c7e376b8e7304.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Added the ability to read GraphML files that are compressed using gzip, with function :func:`~rustworkx.read_graphml`. + The extensions `.graphmlz` and `.gz` are automatically recognised, but the gzip decompression can be forced with the "compression" optional argument. diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 69dcac8dd1..beebd50419 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -646,7 +646,11 @@ def directed_random_bipartite_graph( # Read Write -def read_graphml(path: str, /) -> list[PyGraph | PyDiGraph]: ... +def read_graphml( + path: str, + /, + compression: str | None = ..., +) -> list[PyGraph | PyDiGraph]: ... 
def digraph_node_link_json( graph: PyDiGraph[_S, _T], /, diff --git a/src/graphml.rs b/src/graphml.rs index 6211b25ce2..89c71b79d3 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -13,11 +13,15 @@ #![allow(clippy::borrow_as_ptr)] use std::convert::From; +use std::ffi::OsStr; +use std::fs::File; +use std::io::{BufRead, BufReader}; use std::iter::FromIterator; use std::num::{ParseFloatError, ParseIntError}; use std::path::Path; use std::str::ParseBoolError; +use flate2::bufread::GzDecoder; use hashbrown::HashMap; use indexmap::IndexMap; @@ -524,19 +528,27 @@ impl GraphML { Ok(()) } + /// Open file compressed with gzip, using the GzDecoder + /// Returns a quick_xml Reader instance + fn open_file_gzip>( + path: P, + ) -> Result>>>, quick_xml::Error> { + let file = File::open(path)?; + let reader = BufReader::new(file); + let gzip_reader = BufReader::new(GzDecoder::new(reader)); + Ok(Reader::from_reader(gzip_reader)) + } - /// Parse a file written in GraphML format. + /// Parse a file written in GraphML format from a BufReader /// /// The implementation is based on a state machine in order to /// accept only valid GraphML syntax (e.g a `` element should /// be nested inside a `` element) where the internal state changes /// after handling each quick_xml event. 
- fn from_file>(path: P) -> Result { + fn read_graph_from_reader(mut reader: Reader) -> Result { let mut graphml = GraphML::default(); let mut buf = Vec::new(); - let mut reader = Reader::from_file(path)?; - let mut state = State::Start; let mut domain_of_last_key = Domain::Node; let mut last_data_key = String::new(); @@ -677,6 +689,23 @@ impl GraphML { Ok(graphml) } + + /// Read a graph from a file in the GraphML format + /// If the file extension is "graphmlz" or "gz", or `compression` is "gzip", decompress it on the fly + fn from_file>(path: P, compression: &str) -> Result { + let extension = path.as_ref().extension().unwrap_or(OsStr::new("")); + + let graph: Result = + if extension.eq("graphmlz") || extension.eq("gz") || compression.eq("gzip") { + let reader = Self::open_file_gzip(path)?; + Self::read_graph_from_reader(reader) + } else { + let reader = Reader::from_file(path)?; + Self::read_graph_from_reader(reader) + }; + + graph + } } /// Read a list of graphs from a file in GraphML format. @@ -703,9 +732,13 @@ impl GraphML { /// :rtype: list[Union[PyGraph, PyDiGraph]] /// :raises RuntimeError: when an error is encountered while parsing the GraphML file. 
#[pyfunction] -#[pyo3(text_signature = "(path, /)")] -pub fn read_graphml(py: Python, path: &str) -> PyResult> { - let graphml = GraphML::from_file(path)?; +#[pyo3(signature=(path, compression=None),text_signature = "(path, /, compression=None)")] +pub fn read_graphml( + py: Python, + path: &str, + compression: Option, +) -> PyResult> { + let graphml = GraphML::from_file(path, &compression.unwrap_or_default())?; let mut out = Vec::new(); for graph in graphml.graphs { diff --git a/tests/test_graphml.py b/tests/test_graphml.py index fee85da4ad..517a79d263 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -12,6 +12,8 @@ import unittest import tempfile +import gzip + import numpy import rustworkx @@ -55,8 +57,8 @@ def assertGraphMLRaises(self, graph_xml): with self.assertRaises(Exception): rustworkx.read_graphml(fd.name) - def test_simple(self): - graph_xml = self.HEADER.format( + def graphml_xml_example(self): + return self.HEADER.format( """ yellow @@ -80,6 +82,8 @@ def test_simple(self): """ ) + def test_simple(self): + graph_xml = self.graphml_xml_example() with tempfile.NamedTemporaryFile("wt") as fd: fd.write(graph_xml) fd.flush() @@ -96,6 +100,53 @@ def test_simple(self): ] self.assertGraphEqual(graph, nodes, edges, directed=False) + def test_gzipped(self): + graph_xml = self.graphml_xml_example() + + ## Test reading a graphmlz + with tempfile.NamedTemporaryFile("w+b") as fd: + fd.flush() + newname = fd.name + ".gz" + with gzip.open(newname, "wt") as wf: + wf.write(graph_xml) + + graphml = rustworkx.read_graphml(newname) + graph = graphml[0] + nodes = [ + {"id": "n0", "color": "blue"}, + {"id": "n1", "color": "yellow"}, + {"id": "n2", "color": "green"}, + ] + edges = [ + ("n0", "n1", {"fidelity": 0.98}), + ("n0", "n2", {"fidelity": 0.95}), + ] + self.assertGraphEqual(graph, nodes, edges, directed=False) + + def test_gzipped_force(self): + graph_xml = self.graphml_xml_example() + + ## Test reading a graphmlz + with 
tempfile.NamedTemporaryFile("w+b") as fd: + # flush the temporary file; only its name is reused for the gzip file below + fd.flush() + newname = fd.name + ".ext" + with gzip.open(newname, "wt") as wf: + wf.write(graph_xml) + + graphml = rustworkx.read_graphml(newname, compression="gzip") + graph = graphml[0] + nodes = [ + {"id": "n0", "color": "blue"}, + {"id": "n1", "color": "yellow"}, + {"id": "n2", "color": "green"}, + ] + edges = [ + ("n0", "n1", {"fidelity": 0.98}), + ("n0", "n2", {"fidelity": 0.95}), + ] + self.assertGraphEqual(graph, nodes, edges, directed=False) + def test_multiple_graphs_in_single_file(self): graph_xml = self.HEADER.format( """