Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read gzipped graphml files #1315

Merged
merged 17 commits into from
Nov 19, 2024
Merged
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
smallvec = { version = "1.0", features = ["union"] }
rustworkx-core = { path = "rustworkx-core", version = "=0.16.0" }
flate2 = "1.0.35"

[dependencies.pyo3]
version = "0.22.6"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Added the ability to read GraphML files that are compressed using gzip, with function :func:`~rustworkx.read_graphml`.
Files with the extensions ``.graphmlz`` and ``.gz`` are decompressed automatically; for any other file name, gzip decompression can be forced by passing ``compression="gzip"``.
6 changes: 5 additions & 1 deletion rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,11 @@ def directed_random_bipartite_graph(

# Read Write

def read_graphml(path: str, /) -> list[PyGraph | PyDiGraph]: ...
# `compression` is optional. The Rust implementation decompresses with gzip when
# it equals "gzip" or when the path has the extension `graphmlz` or `gz`.
def read_graphml(
path: str,
/,
compression: str | None = ...,
) -> list[PyGraph | PyDiGraph]: ...
def digraph_node_link_json(
graph: PyDiGraph[_S, _T],
/,
Expand Down
47 changes: 40 additions & 7 deletions src/graphml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
#![allow(clippy::borrow_as_ptr)]

use std::convert::From;
use std::ffi::OsStr;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::iter::FromIterator;
use std::num::{ParseFloatError, ParseIntError};
use std::path::Path;
use std::str::ParseBoolError;

use flate2::bufread::GzDecoder;
use hashbrown::HashMap;
use indexmap::IndexMap;

Expand Down Expand Up @@ -524,19 +528,27 @@ impl GraphML {

Ok(())
}
/// Open a gzip-compressed file and wrap it in a quick_xml `Reader`.
///
/// The raw file is buffered before being handed to the `bufread` flavour of
/// `GzDecoder`, and the decompressed stream is buffered again so the XML
/// reader can consume it as a `BufRead`.
///
/// A failure to open `path` is propagated to the caller.
fn open_file_gzip<P: AsRef<Path>>(
    path: P,
) -> Result<Reader<BufReader<GzDecoder<BufReader<File>>>>, quick_xml::Error> {
    let compressed = BufReader::new(File::open(path)?);
    let decompressed = BufReader::new(GzDecoder::new(compressed));
    Ok(Reader::from_reader(decompressed))
}

/// Parse a file written in GraphML format.
/// Parse a file written in GraphML format from a BufReader
///
/// The implementation is based on a state machine in order to
/// accept only valid GraphML syntax (e.g a `<data>` element should
/// be nested inside a `<node>` element) where the internal state changes
/// after handling each quick_xml event.
fn from_file<P: AsRef<Path>>(path: P) -> Result<GraphML, Error> {
fn read_graph_from_reader<R: BufRead>(mut reader: Reader<R>) -> Result<GraphML, Error> {
let mut graphml = GraphML::default();

let mut buf = Vec::new();
let mut reader = Reader::from_file(path)?;

let mut state = State::Start;
let mut domain_of_last_key = Domain::Node;
let mut last_data_key = String::new();
Expand Down Expand Up @@ -677,6 +689,23 @@ impl GraphML {

Ok(graphml)
}

/// Load a GraphML document from `path`.
///
/// The file is decompressed with gzip on the fly when its extension is
/// `graphmlz` or `gz`, or when `compression` is exactly `"gzip"`; in every
/// other case it is read as plain XML.
fn from_file<P: AsRef<Path>>(path: P, compression: &str) -> Result<GraphML, Error> {
    // A path without an extension is treated as having an empty one.
    let ext = path.as_ref().extension().unwrap_or(OsStr::new(""));
    let use_gzip = ext == "graphmlz" || ext == "gz" || compression == "gzip";

    if use_gzip {
        Self::read_graph_from_reader(Self::open_file_gzip(path)?)
    } else {
        Self::read_graph_from_reader(Reader::from_file(path)?)
    }
}
}

/// Read a list of graphs from a file in GraphML format.
Expand All @@ -703,9 +732,13 @@ impl GraphML {
/// :rtype: list[Union[PyGraph, PyDiGraph]]
/// :raises RuntimeError: when an error is encountered while parsing the GraphML file.
#[pyfunction]
#[pyo3(text_signature = "(path, /)")]
pub fn read_graphml(py: Python, path: &str) -> PyResult<Vec<PyObject>> {
let graphml = GraphML::from_file(path)?;
#[pyo3(signature=(path, compression=None),text_signature = "(path, /, compression=None)")]
pub fn read_graphml(
py: Python,
path: &str,
compression: Option<String>,
) -> PyResult<Vec<PyObject>> {
let graphml = GraphML::from_file(path, &compression.unwrap_or_default())?;

let mut out = Vec::new();
for graph in graphml.graphs {
Expand Down
55 changes: 53 additions & 2 deletions tests/test_graphml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import unittest
import tempfile
import gzip

import numpy

import rustworkx
Expand Down Expand Up @@ -55,8 +57,8 @@ def assertGraphMLRaises(self, graph_xml):
with self.assertRaises(Exception):
rustworkx.read_graphml(fd.name)

def test_simple(self):
graph_xml = self.HEADER.format(
def graphml_xml_example(self):
return self.HEADER.format(
"""
<key id="d0" for="node" attr.name="color" attr.type="string">
<default>yellow</default>
Expand All @@ -80,6 +82,8 @@ def test_simple(self):
"""
)

def test_simple(self):
graph_xml = self.graphml_xml_example()
with tempfile.NamedTemporaryFile("wt") as fd:
fd.write(graph_xml)
fd.flush()
Expand All @@ -96,6 +100,53 @@ def test_simple(self):
]
self.assertGraphEqual(graph, nodes, edges, directed=False)

def test_gzipped(self):
    """A gzip-compressed file with a ``.gz`` extension is decompressed automatically."""
    graph_xml = self.graphml_xml_example()

    # Compress straight into the temporary file so that it is removed when the
    # context exits, instead of leaving a sibling "<name>.gz" file behind.
    with tempfile.NamedTemporaryFile("w+b", suffix=".gz") as fd:
        with gzip.open(fd, "wt") as wf:
            wf.write(graph_xml)
        # Closing the gzip stream does not close `fd`; push the bytes to disk
        # before the file is reopened by name.
        fd.flush()

        graphml = rustworkx.read_graphml(fd.name)
        graph = graphml[0]
        nodes = [
            {"id": "n0", "color": "blue"},
            {"id": "n1", "color": "yellow"},
            {"id": "n2", "color": "green"},
        ]
        edges = [
            ("n0", "n1", {"fidelity": 0.98}),
            ("n0", "n2", {"fidelity": 0.95}),
        ]
        self.assertGraphEqual(graph, nodes, edges, directed=False)

def test_gzipped_force(self):
    """``compression="gzip"`` forces decompression when the extension is not recognised."""
    graph_xml = self.graphml_xml_example()

    # Use an extension that does not trigger automatic detection, and compress
    # straight into the temporary file so that nothing is left behind on disk.
    with tempfile.NamedTemporaryFile("w+b", suffix=".ext") as fd:
        with gzip.open(fd, "wt") as wf:
            wf.write(graph_xml)
        # Closing the gzip stream does not close `fd`; push the bytes to disk
        # before the file is reopened by name.
        fd.flush()

        graphml = rustworkx.read_graphml(fd.name, compression="gzip")
        graph = graphml[0]
        nodes = [
            {"id": "n0", "color": "blue"},
            {"id": "n1", "color": "yellow"},
            {"id": "n2", "color": "green"},
        ]
        edges = [
            ("n0", "n1", {"fidelity": 0.98}),
            ("n0", "n2", {"fidelity": 0.95}),
        ]
        self.assertGraphEqual(graph, nodes, edges, directed=False)

def test_multiple_graphs_in_single_file(self):
graph_xml = self.HEADER.format(
"""
Expand Down