Skip to content

Commit 14d3d29

Browse files
epwalshfacebook-github-bot
authored andcommitted
make ProcessException pickleable (pytorch#70118)
Summary: Fixes pytorch#70116 Happy to add tests if you let me know the best place to put them. cc VitalyFedyunin Pull Request resolved: pytorch#70118 Reviewed By: malfet Differential Revision: D33255899 Pulled By: ejguan fbshipit-source-id: 41d495374182eb28bb8bb421e890eca3bddc077b
1 parent 9c742be commit 14d3d29

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

test/test_multiprocessing_spawn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Owner(s): ["module: multiprocessing"]
22

33
import os
4+
import pickle
45
import random
56
import signal
67
import sys
@@ -218,5 +219,15 @@ def test_process_exited(self):
218219
class ForkTest(TestCase, _TestMultiProcessing):
219220
start_method = 'fork'
220221

222+
223+
class ErrorTest(TestCase):
224+
def test_errors_pickleable(self):
225+
for error in (
226+
mp.ProcessRaisedException("Oh no!", 1, 1),
227+
mp.ProcessExitedException("Oh no!", 1, 1, 1),
228+
):
229+
pickle.loads(pickle.dumps(error))
230+
231+
221232
if __name__ == '__main__':
222233
run_tests()

torch/multiprocessing/spawn.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@ class ProcessException(Exception):
1414

1515
def __init__(self, msg: str, error_index: int, pid: int):
1616
super().__init__(msg)
17+
self.msg = msg
1718
self.error_index = error_index
1819
self.pid = pid
1920

21+
def __reduce__(self):
22+
return type(self), (self.msg, self.error_index, self.pid)
23+
2024

2125
class ProcessRaisedException(ProcessException):
2226
"""
@@ -47,6 +51,12 @@ def __init__(
4751
self.exit_code = exit_code
4852
self.signal_name = signal_name
4953

54+
def __reduce__(self):
55+
return (
56+
type(self),
57+
(self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
58+
)
59+
5060

5161
def _wrap(fn, i, args, error_queue):
5262
# prctl(2) is a Linux specific system call.

0 commit comments

Comments
 (0)