Skip to content

Commit

Permalink
equality support for all the node states
Browse files Browse the repository at this point in the history
  • Loading branch information
ljeub-pometry committed Jan 22, 2025
1 parent 633438e commit 8222b63
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 22 deletions.
7 changes: 7 additions & 0 deletions raphtory/src/db/api/storage/graph/storage_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ impl std::fmt::Display for GraphStorage {
}

impl GraphStorage {
pub fn graph_id(&self) -> usize {
match self {
GraphStorage::Mem(g) => Arc::as_ptr(&g.graph).addr(),
GraphStorage::Unlocked(g) => Arc::as_ptr(g).addr(),
GraphStorage::Disk(g) => Arc::as_ptr(g).addr(),
}
}
#[inline(always)]
pub fn is_immutable(&self) -> bool {
match self {
Expand Down
12 changes: 12 additions & 0 deletions raphtory/src/db/graph/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph> + Debug> Debug
}
}

impl<'graph, G: GraphViewOps<'graph>, GH: GraphViewOps<'graph>> PartialEq for Nodes<'graph, G, GH> {
fn eq(&self, other: &Self) -> bool {
if self.base_graph.core_graph().graph_id() == other.base_graph.core_graph().graph_id() {
// same storage, can use internal ids
self.iter_refs().eq(other.iter_refs())
} else {
// different storage, use external ids
self.id().iter_values().eq(other.id().iter_values())
}
}
}

impl<'graph, G: IntoDynamic, GH: IntoDynamic> Nodes<'graph, G, GH> {
pub fn into_dyn(self) -> Nodes<'graph, DynamicGraph> {
Nodes {
Expand Down
45 changes: 23 additions & 22 deletions raphtory/src/python/graph/node_state/node_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,27 @@ macro_rules! impl_node_state_ops {
self.inner.nodes()
}

fn __eq__<'py>(
&self,
other: &Bound<'py, PyAny>,
py: Python<'py>,
) -> Result<Bound<'py, PyAny>, std::convert::Infallible> {
let res = if let Ok(other) = other.downcast::<Self>() {
let other = Bound::borrow(other);
self.inner.iter_values().eq(other.inner.iter_values())
} else if let Ok(other) = other.extract::<Vec<$value>>() {
self.inner.iter_values().map($to_owned).eq(other.into_iter())
} else if let Ok(other) = other.extract::<HashMap<PyNodeRef, $value>>() {
(self.inner.len() == other.len()
&& other.into_iter().all(|(node, value)| {
self.inner.get_by_node(node).map($to_owned) == Some(value)
}))
} else {
return Ok(PyNotImplemented::get(py).to_owned().into_any());
};
Ok(res.into_pyobject(py)?.to_owned().into_any())
}

fn __iter__(&self) -> PyBorrowingIterator {
py_borrowing_iter!(self.inner.clone(), $inner_t, |inner| inner
.iter_values()
Expand All @@ -62,10 +83,10 @@ macro_rules! impl_node_state_ops {
///
/// Arguments:
/// node (NodeInput): the node
#[doc = concat!(" default (Optional[", $py_value, "]): the default value. Defaults to None.")]
#[doc = concat!(" default (Optional[", $py_value, "]): the default value. Defaults to None.")]
///
/// Returns:
#[doc = concat!(" Optional[", $py_value, "]: the value for the node or the default value")]
#[doc = concat!(" Optional[", $py_value, "]: the value for the node or the default value")]
#[pyo3(signature = (node, default=None::<$value>))]
fn get(&self, node: PyNodeRef, default: Option<$value>) -> Option<$value> {
self.inner.get_by_node(node).map($to_owned).or(default)
Expand Down Expand Up @@ -247,26 +268,6 @@ macro_rules! impl_node_state_ord_ops {
.map(|(n, v)| (n.cloned(), ($to_owned)(v)))
}

fn __eq__<'py>(
&self,
other: &Bound<'py, PyAny>,
py: Python<'py>,
) -> Result<Bound<'py, PyAny>, std::convert::Infallible> {
let res = if let Ok(other) = other.downcast::<Self>() {
let other = Bound::borrow(other);
self.inner.iter_values().eq(other.inner.iter_values())
} else if let Ok(other) = other.extract::<Vec<$value>>() {
self.inner.iter_values().map($to_owned).eq(other.into_iter())
} else if let Ok(other) = other.extract::<HashMap<PyNodeRef, $value>>() {
(self.inner.len() == other.len()
&& other.into_iter().all(|(node, value)| {
self.inner.get_by_node(node).map($to_owned) == Some(value)
}))
} else {
return Ok(PyNotImplemented::get(py).to_owned().into_any());
};
Ok(res.into_pyobject(py)?.to_owned().into_any())
}
}
};
}
Expand Down

0 comments on commit 8222b63

Please sign in to comment.