Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[distGB] graphbolt graph edge's mask will be filled with 0 if these edges have no mask initial #7846

Merged
merged 15 commits into from
Jan 9, 2025
Merged
16 changes: 11 additions & 5 deletions python/dgl/distributed/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def collate(self, items):
raise NotImplementedError

@staticmethod
def add_edge_attribute_to_graph(g, data_name):
def add_edge_attribute_to_graph(g, data_name, gb_padding=0):
"""Add data into the graph as an edge attribute.

For some cases such as prob/mask-based sampling on GraphBolt partitions,
Expand All @@ -327,9 +327,11 @@ def add_edge_attribute_to_graph(g, data_name):
The graph.
data_name : str
The name of data that's stored in DistGraph.ndata/edata.
gb_padding : int, optional
The padding value for GraphBolt partitions' new edge_attributes.
classicsong marked this conversation as resolved.
Show resolved Hide resolved
"""
if g._use_graphbolt and data_name:
g.add_edge_attribute(data_name)
g.add_edge_attribute(data_name, gb_padding)


class NodeCollator(Collator):
Expand All @@ -344,6 +346,8 @@ class NodeCollator(Collator):
The node set to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
gb_padding : int, optional
The padding value for GraphBolt partitions' new edge_attributes.

Examples
--------
Expand All @@ -366,7 +370,7 @@ class NodeCollator(Collator):
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""

def __init__(self, g, nids, graph_sampler):
def __init__(self, g, nids, graph_sampler, gb_padding=0):
self.g = g
if not isinstance(nids, Mapping):
assert (
Expand All @@ -380,7 +384,7 @@ def __init__(self, g, nids, graph_sampler):
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down Expand Up @@ -612,6 +616,7 @@ def __init__(
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
gb_padding=0,
):
self.g = g
if not isinstance(eids, Mapping):
Expand Down Expand Up @@ -642,7 +647,7 @@ def __init__(
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)

@property
Expand Down Expand Up @@ -864,6 +869,7 @@ def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
else:
dataloader_kwargs[k] = v

collator_kwargs["gb_padding"] = 1
CfromBU marked this conversation as resolved.
Show resolved Hide resolved
if device is None:
# for the distributed case default to the CPU
device = "cpu"
Expand Down
15 changes: 9 additions & 6 deletions python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,16 @@ def _copy_data_from_shared_mem(name, shape):
class AddEdgeAttributeFromKVRequest(rpc.Request):
"""Add edge attribute from kvstore to local GraphBolt partition."""

def __init__(self, name, kv_names):
def __init__(self, name, kv_names, padding=0):
CfromBU marked this conversation as resolved.
Show resolved Hide resolved
self._name = name
self._kv_names = kv_names
self._padding = padding

def __getstate__(self):
return self._name, self._kv_names
return self._name, self._kv_names, self._padding

def __setstate__(self, state):
self._name, self._kv_names = state
self._name, self._kv_names, self._padding = state

def process_request(self, server_state):
# For now, this is only used to add prob/mask data to the graph.
Expand All @@ -169,7 +170,7 @@ def process_request(self, server_state):
gpb = server_state.partition_book
# Initialize the edge attribute.
num_edges = g.total_num_edges
attr_data = torch.zeros(num_edges, dtype=data_type)
attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
classicsong marked this conversation as resolved.
Show resolved Hide resolved
Rhett-Ying marked this conversation as resolved.
Show resolved Hide resolved
# Map data from kvstore to the local partition for inner edges only.
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
homo_eids = g.edge_attributes[EID][:num_inner_edges]
Expand Down Expand Up @@ -1620,13 +1621,15 @@ def _get_edata_names(self, etype=None):
edata_names.append(name)
return edata_names

def add_edge_attribute(self, name):
def add_edge_attribute(self, name, padding=0):
CfromBU marked this conversation as resolved.
Show resolved Hide resolved
"""Add an edge attribute into GraphBolt partition from edge data.

Parameters
----------
name : str
The name of the edge attribute.
padding : int, optional
The padding value for the new edge attribute.
"""
# Sanity checks.
if not self._use_graphbolt:
Expand All @@ -1643,7 +1646,7 @@ def add_edge_attribute(self, name):
]
rpc.send_request(
self._client._main_server_id,
AddEdgeAttributeFromKVRequest(name, kv_names),
AddEdgeAttributeFromKVRequest(name, kv_names, padding),
)
# Wait for the response.
assert rpc.recv_response()._name == name
Expand Down
89 changes: 88 additions & 1 deletion tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import unittest
from pathlib import Path

import backend as F
import dgl

import dgl.backend as F
import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -1858,6 +1859,92 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
)


def check_mask_hetero_sampling_gb(tmpdir, num_server, use_graphbolt=True):
def create_hetero_graph(dense=False, empty=False):
CfromBU marked this conversation as resolved.
Show resolved Hide resolved
num_nodes = {"n1": 210, "n2": 200, "n3": 220, "n4": 230}
etypes = [("n1", "r12", "n2"), ("n2", "r23", "n3"), ("n3", "r34", "n4")]
edges = {}
random.seed(42)
for etype in etypes:
src_ntype, _, dst_ntype = etype
arr = spsp.random(
num_nodes[src_ntype] - 10 if empty else num_nodes[src_ntype],
num_nodes[dst_ntype] - 10 if empty else num_nodes[dst_ntype],
density=0.1,
format="coo",
random_state=100,
)
edges[etype] = (arr.row, arr.col)
g = dgl.heterograph(edges, num_nodes)

return g

generate_ip_config("rpc_ip_config.txt", num_server, num_server)

g = create_hetero_graph()
eids = torch.randperm(g.num_edges("r34"))[:10]
mask = torch.zeros(g.num_edges("r34"), dtype=torch.bool)
mask[eids] = True

num_parts = num_server

orig_nid_map, orig_eid_map = partition_graph(
g,
"test_sampling",
num_parts,
tmpdir,
num_hops=1,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=True,
)

part_config = tmpdir / "test_sampling.json"

pserver_list = []
ctx = mp.get_context("spawn")
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
True,
),
)
p.start()
time.sleep(1)
pserver_list.append(p)

dgl.distributed.initialize("rpc_ip_config.txt", use_graphbolt=True)
dist_graph = DistGraph("test_sampling", part_config=part_config)

os.environ["DGL_DIST_DEBUG"] = "1"

edges = {("n3", "r34", "n4"): eids}
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
loader = dgl.dataloading.DistEdgeDataLoader(
dist_graph, edges, sampler, batch_size=64
)

block = next(iter(loader))[2][0]
assert block.num_src_nodes("n1") > 0
CfromBU marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize("num_parts", [1])
def test_local_masked_sampling_heterograph_gb(
num_server,
):
CfromBU marked this conversation as resolved.
Show resolved Hide resolved
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_mask_hetero_sampling_gb(Path(tmpdirname), num_server)


if __name__ == "__main__":
import tempfile

Expand Down
Loading