Skip to content

Commit 3ffd2af

Browse files
aivanou authored and facebook-github-bot committed
Add exception classification to torch.multiprocessing.spawn (pytorch#45174)
Summary: Pull Request resolved: pytorch#45174. Introduce different types of exceptions that map to different failures of torch.multiprocessing.spawn. The change introduces three exception types: ProcessException - the common base class, which records the index and pid of the failed process; ProcessRaisedException - raised when a process started by spawn raises an exception; ProcessExitedException - raised when a process started by spawn exits with an exit code or is terminated by a signal. This allows frameworks that use mp.spawn to categorize failures, which is helpful for tracking metrics and enhancing logs. Test Plan: Imported from OSS. Reviewed By: taohe. Differential Revision: D23889400. Pulled By: tierex. fbshipit-source-id: 8849624c616230a6a81158c52ce0c18beb437330
1 parent da033e0 commit 3ffd2af

File tree

3 files changed

+81
-6
lines changed

3 files changed

+81
-6
lines changed

test/test_multiprocessing_spawn.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ def test_nested_child_body(i, ready_queue, nested_child_sleep):
5454
time.sleep(nested_child_sleep)
5555

5656

57+
# Spawn target that never returns on its own: it sleeps in one-second steps
# forever. `i` is accepted but unused. test_signal_raises starts it and then
# kills the child with SIGTERM.
def test_infinite_task(i):
58+
while True:
59+
time.sleep(1)
60+
61+
62+
# Spawn target that terminates the process immediately with exit status 12.
# `idx` is accepted but unused. test_process_exited checks for that status.
def test_process_exit(idx):
63+
sys.exit(12)
64+
65+
5766
def test_nested(i, pids_queue, nested_child_sleep, start_method):
5867
context = mp.get_context(start_method)
5968
nested_child_ready_queue = context.Queue()
@@ -184,6 +193,23 @@ def test_nested(self):
184193
class SpawnTest(TestCase, _TestMultiProcessing):
    # Runs the shared _TestMultiProcessing cases with the 'spawn' start method
    # and adds checks for the exception types raised by mp.spawn.
    start_method = 'spawn'

    def test_exception_raises(self):
        # A child that raises must surface as ProcessRaisedException.
        with self.assertRaises(mp.ProcessRaisedException):
            mp.spawn(test_success_first_then_exception_func, args=(), nprocs=1)

    def test_signal_raises(self):
        # Start a child that never finishes, kill it with SIGTERM, and expect
        # join() to report the abnormal exit as ProcessExitedException.
        context = mp.spawn(test_infinite_task, args=(), nprocs=1, join=False)
        for pid in context.pids():
            os.kill(pid, signal.SIGTERM)
        with self.assertRaises(mp.ProcessExitedException):
            context.join()

    def test_process_exited(self):
        # A child calling sys.exit(12) must surface as ProcessExitedException
        # carrying that exit code.
        with self.assertRaises(mp.ProcessExitedException) as e:
            mp.spawn(test_process_exit, args=(), nprocs=1)
        # The assertion has to run after the `with` block (statements after
        # the raising call inside it are never reached), and the caught
        # exception lives on `e.exception`, not on the context manager `e`.
        self.assertEqual(12, e.exception.exit_code)
212+
187213
@unittest.skipIf(
188214
IS_WINDOWS,
189215
"Fork is only available on Unix",

torch/multiprocessing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242

4343
"""Add helper function to spawn N processes and wait for completion of any of
4444
them. This depends `mp.get_context` which was added in Python 3.4."""
45-
from .spawn import spawn, SpawnContext, _supports_context, start_processes, ProcessContext
45+
from .spawn import spawn, SpawnContext, _supports_context, start_processes, ProcessContext, \
46+
ProcessRaisedException, ProcessExitedException
4647

4748

4849
if sys.platform == 'darwin' or sys.platform == 'win32':

torch/multiprocessing/spawn.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
from typing import Optional
23
import multiprocessing
34
import multiprocessing.connection
45
import signal
@@ -8,6 +9,45 @@
89
from . import _prctl_pr_set_pdeathsig
910

1011

12+
class ProcessException(Exception):
    """Base class for errors that describe a failed spawned process.

    Attributes:
        error_index: Index of the failed process among the spawned processes.
        error_pid: Process id of the failed process.
        pid: Read/write alias of ``error_pid``, kept so that existing callers
            reading ``exc.pid`` continue to work.
    """
    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.error_index = error_index
        # Store the pid in the slot that is actually declared; previously
        # only ``pid`` was assigned and ``error_pid`` raised AttributeError.
        self.error_pid = pid

    @property
    def pid(self) -> int:
        """Process id of the failed process (alias of ``error_pid``)."""
        return self.error_pid

    @pid.setter
    def pid(self, value: int) -> None:
        self.error_pid = value

    def __reduce__(self):
        # Default exception pickling rebuilds the object from ``self.args``,
        # which holds only the message, so unpickling would fail with a
        # TypeError for the missing positional arguments. Supply all of them.
        return type(self), (self.args[0], self.error_index, self.error_pid)


class ProcessRaisedException(ProcessException):
    """Raised when a spawned process failed because its code raised an exception.

    Args:
        msg: Error message.
        error_index: Index of the failed process.
        error_pid: Process id of the failed process.
    """
    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
    ):
        super().__init__(msg, error_index, error_pid)


class ProcessExitedException(ProcessException):
    """Raised when a spawned process was terminated by a signal or exited
    with a specific exit code.

    Attributes:
        exit_code: Exit code reported for the process.
        signal_name: Name of the terminating signal, or ``None`` when no
            signal name was supplied.
    """
    # Declare every attribute assigned in __init__ (signal_name was missing).
    __slots__ = ["exit_code", "signal_name"]

    def __init__(
        self, msg: str, error_index: int, error_pid: int,
        exit_code: int, signal_name: Optional[str] = None
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        # Extend the base reconstruction arguments with the extra fields so a
        # pickled instance round-trips with exit_code and signal_name intact.
        return (
            type(self),
            (
                self.args[0],
                self.error_index,
                self.error_pid,
                self.exit_code,
                self.signal_name,
            ),
        )
49+
50+
1151
def _wrap(fn, i, args, error_queue):
1252
# prctl(2) is a Linux specific system call.
1353
# On other systems the following function call has no effect.
@@ -98,24 +138,32 @@ def join(self, timeout=None):
98138
process.join()
99139

100140
# There won't be an error on the queue if the process crashed.
141+
failed_process = self.processes[error_index]
101142
if self.error_queues[error_index].empty():
102143
exitcode = self.processes[error_index].exitcode
103144
if exitcode < 0:
104145
name = signal.Signals(-exitcode).name
105-
raise Exception(
146+
raise ProcessExitedException(
106147
"process %d terminated with signal %s" %
107-
(error_index, name)
148+
(error_index, name),
149+
error_index=error_index,
150+
error_pid=failed_process.pid,
151+
exit_code=exitcode,
152+
signal_name=name
108153
)
109154
else:
110-
raise Exception(
155+
raise ProcessExitedException(
111156
"process %d terminated with exit code %d" %
112-
(error_index, exitcode)
157+
(error_index, exitcode),
158+
error_index=error_index,
159+
error_pid=failed_process.pid,
160+
exit_code=exitcode
113161
)
114162

115163
original_trace = self.error_queues[error_index].get()
116164
msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
117165
msg += original_trace
118-
raise Exception(msg)
166+
raise ProcessRaisedException(msg, error_index, failed_process.pid)
119167

120168

121169
class SpawnContext(ProcessContext):

0 commit comments

Comments
 (0)