Skip to content

Commit

Permalink
Read gzipped graphml files (#1315)
Browse files Browse the repository at this point in the history
* Use flate2 crate to read gzipped graphml files

* fix typo

* run rustfmt

* apply suggestion from clippy

* add test for gzipped graphml

* write separate function

* add changelog

* reformat

* Revert "write separate function"

This reverts commit 2dba252.

* run with compression argument

* update contribution

* lint python

* add stub

* try avoid error in test in Windows

* use Option for compression variable

* correct text signature

---------

Co-authored-by: Ivan Carvalho <[email protected]>
  • Loading branch information
fabmazz and IvanIsCoding authored Nov 19, 2024
1 parent ff611f2 commit 0017d00
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 10 deletions.
35 changes: 35 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
smallvec = { version = "1.0", features = ["union"] }
rustworkx-core = { path = "rustworkx-core", version = "=0.16.0" }
flate2 = "1.0.35"

[dependencies.pyo3]
version = "0.22.6"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Added the ability to read GraphML files that are compressed with gzip, using the function :func:`~rustworkx.read_graphml`.
Files with the extensions ``.graphmlz`` and ``.gz`` are automatically recognised as gzip-compressed; for files with any other extension, gzip decompression can be forced by passing ``compression="gzip"`` as the optional ``compression`` argument.
6 changes: 5 additions & 1 deletion rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,11 @@ def directed_random_bipartite_graph(

# Read Write

def read_graphml(path: str, /) -> list[PyGraph | PyDiGraph]: ...
# Read the graphs stored in the GraphML file at ``path``.
# ``compression="gzip"`` forces gzip decompression; when it is omitted, files
# whose extension is ``.graphmlz`` or ``.gz`` are still decompressed automatically.
def read_graphml(
path: str,
/,
compression: str | None = ...,
) -> list[PyGraph | PyDiGraph]: ...
def digraph_node_link_json(
graph: PyDiGraph[_S, _T],
/,
Expand Down
47 changes: 40 additions & 7 deletions src/graphml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
#![allow(clippy::borrow_as_ptr)]

use std::convert::From;
use std::ffi::OsStr;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::iter::FromIterator;
use std::num::{ParseFloatError, ParseIntError};
use std::path::Path;
use std::str::ParseBoolError;

use flate2::bufread::GzDecoder;
use hashbrown::HashMap;
use indexmap::IndexMap;

Expand Down Expand Up @@ -524,19 +528,27 @@ impl GraphML {

Ok(())
}
/// Open a gzip-compressed file for XML parsing.
///
/// The file is wrapped in a `BufReader`, fed through a `GzDecoder`, and the
/// decompressed stream is buffered again before being handed to quick_xml.
/// Returns a quick_xml `Reader` over the decompressed contents; a failure to
/// open the file is propagated as a `quick_xml::Error`.
fn open_file_gzip<P: AsRef<Path>>(
    path: P,
) -> Result<Reader<BufReader<GzDecoder<BufReader<File>>>>, quick_xml::Error> {
    let decoder = GzDecoder::new(BufReader::new(File::open(path)?));
    Ok(Reader::from_reader(BufReader::new(decoder)))
}

/// Parse a file written in GraphML format.
/// Parse a file written in GraphML format from a BufReader
///
/// The implementation is based on a state machine in order to
/// accept only valid GraphML syntax (e.g a `<data>` element should
/// be nested inside a `<node>` element) where the internal state changes
/// after handling each quick_xml event.
fn from_file<P: AsRef<Path>>(path: P) -> Result<GraphML, Error> {
fn read_graph_from_reader<R: BufRead>(mut reader: Reader<R>) -> Result<GraphML, Error> {
let mut graphml = GraphML::default();

let mut buf = Vec::new();
let mut reader = Reader::from_file(path)?;

let mut state = State::Start;
let mut domain_of_last_key = Domain::Node;
let mut last_data_key = String::new();
Expand Down Expand Up @@ -677,6 +689,23 @@ impl GraphML {

Ok(graphml)
}

/// Read a graph from a file in the GraphML format.
///
/// The file is decompressed on the fly with gzip when its extension is
/// "graphmlz" or "gz", or when `compression` is exactly "gzip". Any other
/// `compression` value leaves the decision to the file extension, and the
/// file is otherwise read as plain XML.
///
/// Errors raised while opening or parsing the file are returned as `Error`.
fn from_file<P: AsRef<Path>>(path: P, compression: &str) -> Result<GraphML, Error> {
    // A missing or non-UTF-8 extension never matches, so such files are only
    // decompressed when explicitly requested through `compression`.
    let has_gzip_extension = matches!(
        path.as_ref().extension().and_then(OsStr::to_str),
        Some("graphmlz" | "gz")
    );

    if has_gzip_extension || compression == "gzip" {
        Self::read_graph_from_reader(Self::open_file_gzip(path)?)
    } else {
        Self::read_graph_from_reader(Reader::from_file(path)?)
    }
}
}

/// Read a list of graphs from a file in GraphML format.
Expand All @@ -703,9 +732,13 @@ impl GraphML {
/// :rtype: list[Union[PyGraph, PyDiGraph]]
/// :raises RuntimeError: when an error is encountered while parsing the GraphML file.
#[pyfunction]
#[pyo3(text_signature = "(path, /)")]
pub fn read_graphml(py: Python, path: &str) -> PyResult<Vec<PyObject>> {
let graphml = GraphML::from_file(path)?;
#[pyo3(signature=(path, compression=None),text_signature = "(path, /, compression=None)")]
pub fn read_graphml(
py: Python,
path: &str,
compression: Option<String>,
) -> PyResult<Vec<PyObject>> {
let graphml = GraphML::from_file(path, &compression.unwrap_or_default())?;

let mut out = Vec::new();
for graph in graphml.graphs {
Expand Down
55 changes: 53 additions & 2 deletions tests/test_graphml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import unittest
import tempfile
import gzip

import numpy

import rustworkx
Expand Down Expand Up @@ -55,8 +57,8 @@ def assertGraphMLRaises(self, graph_xml):
with self.assertRaises(Exception):
rustworkx.read_graphml(fd.name)

def test_simple(self):
graph_xml = self.HEADER.format(
def graphml_xml_example(self):
return self.HEADER.format(
"""
<key id="d0" for="node" attr.name="color" attr.type="string">
<default>yellow</default>
Expand All @@ -80,6 +82,8 @@ def test_simple(self):
"""
)

def test_simple(self):
graph_xml = self.graphml_xml_example()
with tempfile.NamedTemporaryFile("wt") as fd:
fd.write(graph_xml)
fd.flush()
Expand All @@ -96,6 +100,53 @@ def test_simple(self):
]
self.assertGraphEqual(graph, nodes, edges, directed=False)

def test_gzipped(self):
    """A gzip-compressed file with a ``.gz`` extension is read transparently."""
    import os

    graph_xml = self.graphml_xml_example()

    # Create the compressed file inside a temporary directory so that it is
    # removed when the test finishes instead of being left behind on disk.
    with tempfile.TemporaryDirectory() as tmpdir:
        newname = os.path.join(tmpdir, "graph.graphml.gz")
        with gzip.open(newname, "wt") as wf:
            wf.write(graph_xml)

        # No ``compression`` argument: the ".gz" extension alone selects gzip.
        graphml = rustworkx.read_graphml(newname)
        graph = graphml[0]
        nodes = [
            {"id": "n0", "color": "blue"},
            {"id": "n1", "color": "yellow"},
            {"id": "n2", "color": "green"},
        ]
        edges = [
            ("n0", "n1", {"fidelity": 0.98}),
            ("n0", "n2", {"fidelity": 0.95}),
        ]
        self.assertGraphEqual(graph, nodes, edges, directed=False)

def test_gzipped_force(self):
    """``compression="gzip"`` forces decompression of an unrecognised extension."""
    import os

    graph_xml = self.graphml_xml_example()

    # Create the compressed file inside a temporary directory so that it is
    # removed when the test finishes instead of being left behind on disk.
    with tempfile.TemporaryDirectory() as tmpdir:
        # ".ext" is not a recognised gzip extension, so decompression only
        # happens because it is requested explicitly below.
        newname = os.path.join(tmpdir, "graph.ext")
        with gzip.open(newname, "wt") as wf:
            wf.write(graph_xml)

        graphml = rustworkx.read_graphml(newname, compression="gzip")
        graph = graphml[0]
        nodes = [
            {"id": "n0", "color": "blue"},
            {"id": "n1", "color": "yellow"},
            {"id": "n2", "color": "green"},
        ]
        edges = [
            ("n0", "n1", {"fidelity": 0.98}),
            ("n0", "n2", {"fidelity": 0.95}),
        ]
        self.assertGraphEqual(graph, nodes, edges, directed=False)

def test_multiple_graphs_in_single_file(self):
graph_xml = self.HEADER.format(
"""
Expand Down

0 comments on commit 0017d00

Please sign in to comment.