🐛 Describe the bug
I encountered an issue while testing examples/distributed/pyg/node_ogb_cpu.py with four nodes and small batch sizes (e.g., 10). Using batch_size=1 triggers the error immediately. The script fails when all source nodes belong to the same partition, at the following line of the _get_sampler_output function in torch_geometric/distributed/dist_neighbor_sampler.py:
    cumsum_neighbors_per_node = outputs[p_id].metadata[0]
Error message:
    AttributeError: 'NoneType' object has no attribute 'metadata'
I noticed that p_id is 0 at this point, but outputs[0] is None:
    [None, None, SamplerOutput(...), None]
It seems that p_id is being computed incorrectly in the following code segment:
    if not local_only:
        # Src nodes are remote
        res_fut_list = await to_asyncio_future(
            torch.futures.collect_all(futs))
        for i, res_fut in enumerate(res_fut_list):
            p_id = (self.graph_store.partition_idx + i +
                    1) % self.graph_store.num_partitions
            p_outputs.pop(p_id)
            p_outputs.insert(p_id, res_fut.wait())
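To illustrate how an offset-based index like the one above could misplace results, here is a minimal sketch in plain Python. It rests on an assumption the issue does not confirm: that futs only holds futures for partitions that actually received source nodes, so the enumeration index i no longer lines up with the partition order. The function names place_by_offset and place_by_recorded_id are hypothetical, and strings stand in for SamplerOutput objects.

```python
def place_by_offset(local_idx, num_partitions, remote_results):
    """Mimic the quoted loop: the i-th result goes to slot
    (local_idx + i + 1) % num_partitions, ignoring which partition sent it."""
    outputs = [None] * num_partitions
    for i, (_, result) in enumerate(remote_results):
        p_id = (local_idx + i + 1) % num_partitions
        outputs[p_id] = result
    return outputs


def place_by_recorded_id(num_partitions, remote_results):
    """Possible alternative (an assumption, not the library's fix):
    keep the partition id next to each result and index by it."""
    outputs = [None] * num_partitions
    for p_id, result in remote_results:
        outputs[p_id] = result
    return outputs


# Hypothetical scenario: 4 partitions, the local partition is 0,
# and only partition 2 received source nodes, so one result comes back.
remote_results = [(2, "output-from-partition-2")]

by_offset = place_by_offset(0, 4, remote_results)
by_id = place_by_recorded_id(4, remote_results)

print(by_offset)  # [None, 'output-from-partition-2', None, None]
print(by_id)      # [None, None, 'output-from-partition-2', None]
```

Under that assumption, the offset formula is only correct when every remote partition contributes exactly one future, in ring order starting after the local partition. Recording the partition id alongside each future would avoid relying on that ordering.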
Versions
Versions of relevant libraries:
[pip3] numpy==2.0.2
[pip3] torch==2.4.1
[pip3] torch_cluster==1.6.3+pt24cu124
[pip3] torch-geometric==2.6.0
[pip3] torch_scatter==2.1.2+pt24cu124
[pip3] torch_sparse==0.6.18+pt24cu124
[pip3] torch_spline_conv==1.2.2+pt24cu124
[pip3] torchaudio==2.4.1
[pip3] torchvision==0.19.1
[pip3] triton==3.0.0
[conda] blas 1.0 mkl conda-forge
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] numpy 2.0.2 pypi_0 pypi
[conda] pytorch 2.4.1 py3.12_cuda12.4_cudnn9.1.0_0 pytorch
[conda] pytorch-cuda 12.4 hc786d27_6 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch-cluster 1.6.3+pt24cu124 pypi_0 pypi
[conda] torch-geometric 2.6.0 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt24cu124 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt24cu124 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt24cu124 pypi_0 pypi
[conda] torchaudio 2.4.1 py312_cu124 pytorch
[conda] torchtriton 3.0.0 py312 pytorch
[conda] torchvision 0.19.1 py312_cu124 pytorch