
Commit bac4cfd

Alvant authored and facebook-github-bot committed
Fix mp serialization for integer nn.Parameter on CUDA (pytorch#56529)
Summary: Fixes pytorch#56342
Pull Request resolved: pytorch#56529
Reviewed By: albanD
Differential Revision: D27896094
Pulled By: ngimel
fbshipit-source-id: fe817781eb7139ea57c78acfd56e7c11b61eb4ed
1 parent febff45 commit bac4cfd

2 files changed, +27 −5 lines


test/test_multiprocessing.py

Lines changed: 20 additions & 3 deletions
@@ -832,14 +832,31 @@ def test_cuda_parameter_sharing(self):
 
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
        don't support multiprocessing with spawn start method")
-    def test_integer_parameter_serialization(self):
-        iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
+    def test_integer_parameter_serialization_cpu(self):
+        self._test_integer_parameter_serialization(device='cpu')
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+       don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_integer_parameter_serialization_cuda(self):
+        self._test_integer_parameter_serialization(device='cuda')
+
+    def _test_integer_parameter_serialization(self, device):
+        param = torch.nn.Parameter(
+            torch.tensor(0, dtype=torch.int64, device=device),
+            requires_grad=False
+        )
 
         ctx = mp.get_context('spawn')
-        p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
+        p = ctx.Process(target=integer_parameter_serialization, args=(param,))
         p.start()
         p.join()
 
+        self.assertEqual(
+            0, p.exitcode,
+            msg=f'Failed to serialize successfully for "{device}" device!'
+        )
+
     def test_empty_shared(self):
         t = torch.tensor([])
         t.share_memory_()
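For reference, a minimal standalone sketch of the pattern the new tests exercise, assuming a CUDA-capable machine with spawn start support; the worker below is only a stand-in for the module-level integer_parameter_serialization helper that the test file refers to:

import torch
import torch.multiprocessing as mp

def integer_parameter_serialization(param):
    # The child process only needs to receive the parameter; the interesting
    # part is the rebuild that happens while unpickling in the spawned process.
    pass

if __name__ == '__main__':
    # Sketch of _test_integer_parameter_serialization(device='cuda').
    param = torch.nn.Parameter(
        torch.tensor(0, dtype=torch.int64, device='cuda'),
        requires_grad=False)
    ctx = mp.get_context('spawn')
    p = ctx.Process(target=integer_parameter_serialization, args=(param,))
    p.start()
    p.join()
    assert p.exitcode == 0, 'Failed to serialize the integer Parameter'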

torch/multiprocessing/reductions.py

Lines changed: 7 additions & 2 deletions
@@ -123,9 +123,14 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
             storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
 
     t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
+
     if tensor_cls == torch.nn.parameter.Parameter:
-        t = torch.nn.parameter.Parameter(t)
-        t.requires_grad = requires_grad
+        # It is crucial for integer tensors to receive
+        # the requires_grad=False as an argument in the constructor
+        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+    else:
+        t.requires_grad = requires_grad
+
     return t
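For context, a rough illustration of why forwarding requires_grad to the constructor matters, not part of the commit itself: wrapping an integer tensor in nn.Parameter with the default requires_grad=True raises at construction time, so the rebuilt tensor has to be wrapped with the flag passed explicitly rather than assigned afterwards.

import torch

t = torch.tensor(0, dtype=torch.int64)

try:
    torch.nn.Parameter(t)  # default requires_grad=True
except RuntimeError as e:
    # Integer tensors cannot require gradients, so the old rebuild path
    # (construct first, assign requires_grad afterwards) failed right here.
    print(e)

# The patched rebuild_cuda_tensor instead forwards the flag to the constructor:
p = torch.nn.Parameter(t, requires_grad=False)
print(p.requires_grad)  # False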