Skip to content

Commit

Permalink
[distGB] graphbolt graph edge's mask will be filled with 0 if these e…
Browse files Browse the repository at this point in the history
…dges have no mask initial (#7846)
  • Loading branch information
CfromBU authored Jan 9, 2025
1 parent 540dd2b commit 17017c2
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 14 deletions.
25 changes: 18 additions & 7 deletions python/dgl/distributed/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def collate(self, items):
raise NotImplementedError

@staticmethod
def add_edge_attribute_to_graph(g, data_name):
def add_edge_attribute_to_graph(g, data_name, gb_padding):
"""Add data into the graph as an edge attribute.
For some cases such as prob/mask-based sampling on GraphBolt partitions,
Expand All @@ -327,9 +327,11 @@ def add_edge_attribute_to_graph(g, data_name):
The graph.
data_name : str
The name of data that's stored in DistGraph.ndata/edata.
gb_padding : int
The padding value used to initialize the new edge attribute in GraphBolt partitions.
"""
if g._use_graphbolt and data_name:
g.add_edge_attribute(data_name)
g.add_edge_attribute(data_name, gb_padding)


class NodeCollator(Collator):
Expand All @@ -344,6 +346,11 @@ class NodeCollator(Collator):
The node set to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
gb_padding : int, optional
The value used to fill a new GraphBolt edge attribute (e.g. the prob/mask used for sampling)
when that attribute is None in the DistGraph. Default: 1.
dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors samples an edge only when its mask is 1,
so padding with 1 keeps edges without an initial mask eligible for sampling.
The value is passed to add_edge_attribute_to_graph when the edge attribute is added to GraphBolt.
Examples
--------
Expand All @@ -366,7 +373,7 @@ class NodeCollator(Collator):
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""

def __init__(self, g, nids, graph_sampler):
def __init__(self, g, nids, graph_sampler, gb_padding=1):
self.g = g
if not isinstance(nids, Mapping):
assert (
Expand All @@ -380,7 +387,7 @@ def __init__(self, g, nids, graph_sampler):
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down Expand Up @@ -508,8 +515,11 @@ class EdgeCollator(Collator):
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
gb_padding : int, optional
The value used to fill a new GraphBolt edge attribute (e.g. the prob/mask used for sampling)
when that attribute is None in the DistGraph. Default: 1.
dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors samples an edge only when its mask is 1,
so padding with 1 keeps edges without an initial mask eligible for sampling.
The value is passed to add_edge_attribute_to_graph when the edge attribute is added to GraphBolt.
Examples
--------
The following example shows how to train a 3-layer GNN for edge classification on a
set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
Expand Down Expand Up @@ -612,6 +622,7 @@ def __init__(
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
gb_padding=1,
):
self.g = g
if not isinstance(eids, Mapping):
Expand Down Expand Up @@ -642,7 +653,7 @@ def __init__(
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down
21 changes: 15 additions & 6 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,16 @@ def _copy_data_from_shared_mem(name, shape):
class AddEdgeAttributeFromKVRequest(rpc.Request):
"""Add edge attribute from kvstore to local GraphBolt partition."""

def __init__(self, name, kv_names):
def __init__(self, name, kv_names, padding):
self._name = name
self._kv_names = kv_names
self._padding = padding

def __getstate__(self):
return self._name, self._kv_names
return self._name, self._kv_names, self._padding

def __setstate__(self, state):
self._name, self._kv_names = state
self._name, self._kv_names, self._padding = state

def process_request(self, server_state):
# For now, this is only used to add prob/mask data to the graph.
Expand All @@ -169,7 +170,13 @@ def process_request(self, server_state):
gpb = server_state.partition_book
# Initialize the edge attribute.
num_edges = g.total_num_edges
attr_data = torch.zeros(num_edges, dtype=data_type)

# Padding is used to fill missing edge attributes (e.g., 'prob' or 'mask') for certain edge types.
# In DGLGraph, some edges may lack these attributes or have them set to None, but DGL will still sample these edges.
# In contrast, GraphBolt samples edges based on specific attributes (e.g., 'mask' == 1) and will skip edges with missing attributes.
# To ensure consistent sampling behavior in GraphBolt, we pad missing attributes with default values (e.g., 'mask' = 1),
# allowing all edges to be sampled, even if their attributes were missing or None in DGLGraph.
attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
# Map data from kvstore to the local partition for inner edges only.
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
homo_eids = g.edge_attributes[EID][:num_inner_edges]
Expand Down Expand Up @@ -1620,13 +1627,15 @@ def _get_edata_names(self, etype=None):
edata_names.append(name)
return edata_names

def add_edge_attribute(self, name):
def add_edge_attribute(self, name, padding):
"""Add an edge attribute into GraphBolt partition from edge data.
Parameters
----------
name : str
The name of the edge attribute.
padding : int
The padding value used to initialize the new edge attribute.
"""
# Sanity checks.
if not self._use_graphbolt:
Expand All @@ -1643,7 +1652,7 @@ def add_edge_attribute(self, name):
]
rpc.send_request(
self._client._main_server_id,
AddEdgeAttributeFromKVRequest(name, kv_names),
AddEdgeAttributeFromKVRequest(name, kv_names, padding),
)
# Wait for the response.
assert rpc.recv_response()._name == name
Expand Down
78 changes: 77 additions & 1 deletion tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import unittest
from pathlib import Path

import backend as F
import dgl

import dgl.backend as F
import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -1858,6 +1859,81 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
)


def check_hetero_dist_edge_dataloader_gb(
    tmpdir, num_server, use_graphbolt=True
):
    """Run a DistEdgeDataLoader with mask-based sampling on a partitioned heterograph.

    Partitions a random heterograph, starts ``num_server`` server processes,
    builds a ``DistEdgeDataLoader`` over ten random ``r23`` edges using a
    sampler configured with ``mask="mask"``, and asserts that the first
    batch's block contains ``n1`` source nodes and ``r12``/``r13``/``r23``
    edges.

    Parameters
    ----------
    tmpdir : pathlib.Path
        Directory passed to ``partition_graph`` and ``start_server``; the
        partition config is read from ``tmpdir / "test_sampling.json"``.
    num_server : int
        Number of server processes to start; also used as the number of
        partitions.
    use_graphbolt : bool, optional
        Forwarded to ``partition_graph`` only. ``dgl.distributed.initialize``
        below is always called with ``use_graphbolt=True``.
    """
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)

    g = create_random_hetero()
    # Pick ten random "r23" edges to use as the seed edges of the loader.
    eids = torch.randperm(g.num_edges("r23"))[:10]
    # NOTE(review): `mask` is built here but never attached to `g` in this
    # function; presumably intentional so the edges start without a mask
    # attribute and the padding path is exercised -- confirm.
    mask = torch.zeros(g.num_edges("r23"), dtype=torch.bool)
    mask[eids] = True

    # One partition per server.
    num_parts = num_server

    # The returned ID mappings are not used below.
    orig_nid_map, orig_eid_map = partition_graph(
        g,
        "test_sampling",
        num_parts,
        tmpdir,
        num_hops=1,
        part_method="metis",
        return_mapping=True,
        use_graphbolt=use_graphbolt,
        store_eids=True,
    )

    part_config = tmpdir / "test_sampling.json"

    pserver_list = []
    ctx = mp.get_context("spawn")
    for i in range(num_server):
        # The meaning of the positional arguments is defined by
        # `start_server`, which lives elsewhere in this file.
        p = ctx.Process(
            target=start_server,
            args=(
                i,
                tmpdir,
                num_server > 1,
                "test_sampling",
                ["csc", "coo"],
                True,
            ),
        )
        p.start()
        # Pause one second between server launches.
        time.sleep(1)
        pserver_list.append(p)

    dgl.distributed.initialize("rpc_ip_config.txt", use_graphbolt=True)
    dist_graph = DistGraph("test_sampling", part_config=part_config)

    os.environ["DGL_DIST_DEBUG"] = "1"

    edges = {("n2", "r23", "n3"): eids}
    # Mask-based sampling: the sampler refers to an edge attribute named "mask".
    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
    loader = dgl.dataloading.DistEdgeDataLoader(
        dist_graph, edges, sampler, batch_size=64
    )
    # NOTE(review): the client exits and the servers are joined before the
    # loader is iterated below -- confirm this ordering is intended.
    dgl.distributed.exit_client()
    for p in pserver_list:
        p.join()
        assert p.exitcode == 0

    # Index [2][0] of the first batch; presumably the first block of the
    # sampled block list -- confirm against DistEdgeDataLoader's batch layout.
    block = next(iter(loader))[2][0]
    assert block.num_src_nodes("n1") > 0
    assert block.num_edges("r12") > 0
    assert block.num_edges("r13") > 0
    assert block.num_edges("r23") > 0


def test_hetero_dist_edge_dataloader_gb(num_server=1):
    """Run the GraphBolt hetero edge-dataloader check in distributed mode.

    Resets the environment via ``reset_envs``, sets ``DGL_DIST_MODE`` to
    ``"distributed"``, and runs ``check_hetero_dist_edge_dataloader_gb``
    inside a temporary directory that is removed afterwards.

    Parameters
    ----------
    num_server : int, optional
        Number of servers forwarded to the check. Default: 1.
    """
    reset_envs()
    os.environ["DGL_DIST_MODE"] = "distributed"
    with tempfile.TemporaryDirectory() as scratch_dir:
        workdir = Path(scratch_dir)
        check_hetero_dist_edge_dataloader_gb(workdir, num_server)


if __name__ == "__main__":
import tempfile

Expand Down

0 comments on commit 17017c2

Please sign in to comment.