Skip to content

Commit

Permalink
Merge branch 'nicolasdugue-sinr-nodeembedding'
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaCappelletti94 committed Mar 5, 2024
2 parents 62ba66f + f317b2f commit 39944b3
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,9 @@ dmypy.json

# Pyre type checker
.pyre/

# VSCode project settings
.vscode/

# Mac OS X clutter
**/.DS_Store
44 changes: 44 additions & 0 deletions examples/structral_node_embedding/sinr_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""SINr illustrative example.
Nodes in both cliques (barbell graph) will get the same embedding vectors,
except for those connected to the path.
Nodes in the path are in distinct communities with a high-enough gamma,
and will thus get distinct vectors.
"""

import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from sklearn.decomposition import PCA
from karateclub.node_embedding.structural import SINr


def embed_and_plot(graph: nx.Graph, gamma: float, ax: Axes) -> None:
    """Embed the graph using SINr and plot the 2D PCA projection.

    Args:
        graph (nx.Graph): The graph to embed.
        gamma (float): The modularity multi-resolution parameter. Both
            integer and fractional values are accepted (this script uses
            0.5 and 10).
        ax (Axes): The matplotlib axis to plot the graph on.
    """
    model = SINr(gamma=gamma)
    model.fit(graph)
    embedding = model.get_embedding()

    # Project the embedding onto its first two principal components so that
    # it can be drawn on a plane.
    pca_embedding = PCA(n_components=2).fit_transform(embedding)

    ax.scatter(pca_embedding[:, 0], pca_embedding[:, 1])
    # Label every point with its row index in the embedding matrix.
    for idx, (x, y) in enumerate(pca_embedding):
        ax.annotate(str(idx), (x, y))


if __name__ == "__main__":
    # Two cliques of 4 nodes joined by a path of 8 nodes.
    barbell = nx.barbell_graph(4, 8)
    fig, axs = plt.subplots(3)

    # Top panel: the graph itself.
    nx.draw_kamada_kawai(barbell, with_labels=True, ax=axs[0])

    # Remaining panels: one SINr embedding per resolution value.
    for gamma, axis in zip((0.5, 10), axs[1:]):
        embed_and_plot(barbell, gamma, axis)

    plt.show()
72 changes: 36 additions & 36 deletions karateclub/estimator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""General Estimator base class."""

import warnings
from typing import List
import re
import random
import numpy as np
import networkx as nx
import warnings
from typing import List
from tqdm.auto import trange
import re

"""General Estimator base class."""


class Estimator(object):
Expand All @@ -16,38 +16,41 @@ class Estimator(object):

def __init__(self):
"""Creating an estimator."""
pass

def fit(self):
    """Fit a model.

    Placeholder: the base class implementation does nothing and
    returns ``None``.
    """
    return None

def get_embedding(self):
    """Get the embeddings (graph or node level).

    Placeholder: the base class implementation does nothing and
    returns ``None``.
    """
    return None

def get_memberships(self):
    """Get the membership dictionary.

    Placeholder: the base class implementation does nothing and
    returns ``None``.
    """
    return None

def get_cluster_centers(self):
    """Get the cluster centers.

    Placeholder: the base class implementation does nothing and
    returns ``None``.
    """
    return None


def get_params(self):
    """Get parameter dictionary for this estimator.

    Every instance attribute whose name does not start with an underscore
    is reported as a parameter; underscore-prefixed attributes are left out.

    Return types:
        * **params** *(dict)* - New dictionary mapping attribute names to their current values.
    """
    # A plain prefix test is enough here, no regular expression is needed.
    return {
        key: value
        for key, value in self.__dict__.items()
        if not key.startswith("_")
    }

def set_params(self, **parameters):
    """Set the parameters of this estimator.

    Arg types:
        * **parameters** *(keyword arguments)* - Attribute names and the values to assign to them.

    Return types:
        * **self** - The estimator itself, so that calls can be chained.
    """
    for name in parameters:
        setattr(self, name, parameters[name])
    return self

def _set_seed(self):
"""Creating the initial random seed."""
random.seed(self.seed)
np.random.seed(self.seed)

@staticmethod
def _ensure_walk_traversal_conditions(graph: nx.classes.graph.Graph) -> nx.classes.graph.Graph:
def _ensure_walk_traversal_conditions(
graph: nx.classes.graph.Graph,
) -> nx.classes.graph.Graph:
"""Ensure walk traversal conditions."""
for node_index in trange(
graph.number_of_nodes(),
Expand All @@ -57,37 +60,34 @@ def _ensure_walk_traversal_conditions(graph: nx.classes.graph.Graph) -> nx.class
# for this process to take a bit of time.
disable=graph.number_of_nodes() < 10_000,
desc="Checking main diagonal existance",
dynamic_ncols=True
dynamic_ncols=True,
):
if not graph.has_edge(node_index, node_index):
warnings.warn(
(
"Please do be advised that "
"the graph you have provided does not "
"contain (some) edges in the main "
"diagonal, for instance the self-loop "
"constitued of ({}, {}). These selfloops "
"are necessary to ensure that the graph "
"is traversable, and for this reason we "
"create a copy of the graph and add therein "
"the missing edges. Since we are creating "
"a copy, this will immediately duplicate "
"the memory requirements. To avoid this double "
"allocation, you can provide the graph with the selfloops."
).format(
node_index,
node_index
)
"Please do be advised that "
"the graph you have provided does not "
"contain (some) edges in the main "
"diagonal, for instance the self-loop "
f"constitued of ({node_index}, {node_index}). These selfloops "
"are necessary to ensure that the graph "
"is traversable, and for this reason we "
"create a copy of the graph and add therein "
"the missing edges. Since we are creating "
"a copy, this will immediately duplicate "
"the memory requirements. To avoid this double "
"allocation, you can provide the graph with the selfloops."
)
# We create a copy of the graph
graph = graph.copy()
# And we add the missing edges
# for filling the main diagonal
graph.add_edges_from((
(index, index)
for index in range(graph.number_of_nodes())
if not graph.has_edge(index, index)
))
graph.add_edges_from(
(
(index, index)
for index in range(graph.number_of_nodes())
if not graph.has_edge(index, index)
)
)
break

return graph
Expand Down
2 changes: 2 additions & 0 deletions karateclub/node_embedding/structural/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
"""Submodule for the structural node embedding methods."""
from .graphwave import GraphWave
from .role2vec import Role2Vec
from .sinr import SINr
119 changes: 119 additions & 0 deletions karateclub/node_embedding/structural/sinr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Implementation of SINr: Fast Computing of Sparse Interpretable Node Representations."""

from typing import List, Set, Optional
import networkx as nx
from scipy.sparse import csr_matrix
from sklearn.preprocessing import normalize
import numpy as np
from karateclub.estimator import Estimator


class SINr(Estimator):
    r"""An implementation of `"SINr" <https://inria.hal.science/hal-03197434/>`_
    from the IDA '21 best paper "SINr: Fast Computing of Sparse Interpretable Node Representations is not a Sin!".
    The procedure performs community detection using the Louvain algorithm, and computes the distribution of edges
    of each node across all communities.
    The algorithm is one of the fastest, because it mostly relies on Louvain community detection.
    It thus runs in quasi-linear time. Regarding space complexity, the adjacency matrix and the community
    membership matrix need to be stored, it is also quasi-linear.

    Args:
        gamma (float): modularity multi-resolution parameter. Default is 1.
            The dimension parameter does not exist for SINr, gamma should be used instead:
            the number of dimensions of the embedding space is based on the number of communities uncovered.
            The higher gamma is, the more communities are detected, the higher the number of dimensions of
            the latent space are uncovered. For small graphs, setting gamma to 1 is usually sufficient.
            For bigger graphs, it is recommended to increase gamma (5 or 10 for example).
            For word co-occurrence graphs, to deal with word embedding, gamma is usually set to 50 in order to get many small communities.
        seed (int): Random seed value. Default is 42.
    """

    def __init__(
        self,
        gamma: float = 1,
        seed: int = 42,
    ):
        self.gamma: float = gamma
        self.seed: int = seed
        # Both counters are filled in by ``fit``.
        self.number_of_nodes: Optional[int] = None
        self.number_of_communities: Optional[int] = None
        # Sparse (nodes x communities) matrix computed by ``fit``;
        # ``None`` until the model has been fitted.
        self._embedding = None

    def fit(self, graph: nx.classes.graph.Graph):
        """
        Fitting a SINr model.

        Arg types:
            * **graph** *(NetworkX graph)* - The graph to be embedded.
        """
        self._set_seed()
        graph = self._check_graph(graph)
        self.number_of_nodes = graph.number_of_nodes()
        # The membership matrix built below is indexed by node label, so the
        # adjacency matrix must use that same ordering rather than the order
        # in which the nodes were inserted into the graph.
        # NOTE(review): assumes the nodes are labelled 0..n-1, which
        # `_check_graph` (defined outside this class) is presumed to
        # enforce - confirm.
        adjacency = nx.adjacency_matrix(
            graph, nodelist=list(range(self.number_of_nodes))
        )
        # Make every row of the matrix sum to 1.
        norm_adjacency = normalize(adjacency, norm="l1")
        # Detect communities with the Louvain algorithm, using gamma as the
        # resolution parameter.
        communities = nx.community.louvain_communities(
            graph, resolution=self.gamma, seed=self.seed
        )
        self.number_of_communities = len(communities)
        # Get the community membership of the graph.
        membership_matrix = self._get_matrix_membership(communities)
        # Node-recall: for each node, the distribution of its links across
        # the communities.
        self._embedding = norm_adjacency.dot(membership_matrix)

    def _get_matrix_membership(self, list_of_communities: List[Set[int]]):
        r"""Getting the membership matrix describing for each node (rows), in which community (column) it belongs.

        Arg types:
            * **list_of_communities** *(list of sets of ints)* - The communities, each given as the set of its node indices.

        Return types:
            * **Membership matrix** *(scipy sparse matrix)* - Size nodes, communities
        """
        # Most entries are zero, so a sparse matrix is used. We first build a
        # (communities x nodes) CSR matrix, one row per community, and return
        # its transpose.
        #
        # A CSR matrix is described by three arrays: the data array (the
        # non-zero values), the indices array (the column of each non-zero
        # value) and the index-pointer array (where each row starts in the
        # other two arrays).

        # Indices array: the node index of every (community, node) entry.
        node_indices = np.empty(self.number_of_nodes, dtype=np.uint32)
        # Index-pointer array: the offset at which each community starts.
        community_offsets = np.empty(
            self.number_of_communities + 1, dtype=np.uint32
        )
        offset: int = 0

        for row_index, community in enumerate(list_of_communities):
            # The nodes of this community start at the current offset.
            community_offsets[row_index] = offset
            for node in community:
                node_indices[offset] = node
                offset += 1

        # Every node must have been assigned to exactly one community.
        assert offset == self.number_of_nodes

        # The last entry of the index-pointer array marks the end of the
        # last row, which is the total number of stored entries.
        community_offsets[-1] = self.number_of_nodes

        # Data array: every membership has value one.
        memberships = np.ones(self.number_of_nodes, dtype=np.float32)

        return csr_matrix(
            (memberships, node_indices, community_offsets),
            shape=(self.number_of_communities, self.number_of_nodes),
        ).T

    def get_embedding(self) -> np.ndarray:
        r"""Getting the node embedding.

        Return types:
            * **embedding** *(Numpy array)* - The embedding of nodes.
        """
        if self._embedding is None:
            raise ValueError(
                "No embedding has been computed. "
                "Please call the fit method first."
            )

        # The embedding is stored sparse; callers get a dense array.
        return self._embedding.toarray()
35 changes: 34 additions & 1 deletion test/structral_node_embedding_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import networkx as nx
from karateclub import Role2Vec, GraphWave
from karateclub import Role2Vec, GraphWave, SINr


def test_role2vec():
Expand Down Expand Up @@ -73,3 +73,36 @@ def test_graphwave():
assert embedding.shape[0] == graph.number_of_nodes()
assert embedding.shape[1] == 2 * model.sample_number
assert type(embedding) == np.ndarray



def test_sinr():
    """
    Testing the SINr class.
    """
    # The random graphs are seeded so that the test is reproducible.
    model = SINr()

    graph = nx.watts_strogatz_graph(100, 10, 0.5, seed=42)

    model.fit(graph)

    embedding = model.get_embedding()

    assert embedding.shape[0] == graph.number_of_nodes()
    assert embedding.shape[1] == model.number_of_communities
    assert isinstance(embedding, np.ndarray)

    model = SINr(gamma=5)

    graph = nx.watts_strogatz_graph(200, 10, 0.5, seed=42)

    model.fit(graph)

    embedding = model.get_embedding()

    assert embedding.shape[0] == graph.number_of_nodes()
    assert embedding.shape[1] == model.number_of_communities
    assert isinstance(embedding, np.ndarray)

    # A higher resolution is expected to uncover more communities.
    model2 = SINr(gamma=10)
    model2.fit(graph)
    assert model2.number_of_communities > model.number_of_communities
13 changes: 12 additions & 1 deletion test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,15 @@ def test_get_params():
params = model.get_params()
assert len(params) != 0
assert type(params) is dict
assert '_embedding' not in params
assert '_embedding' not in params

def test_set_params():
    """Check that ``set_params`` overrides the stored parameters."""
    model = DeepWalk()
    params_before = model.get_params()

    model.set_params(dimensions=1, seed=123)
    params_after = model.get_params()

    assert params_after != params_before
    assert params_after['dimensions'] == 1
    assert params_after['seed'] == 123

0 comments on commit 39944b3

Please sign in to comment.