
Commit d048893

eellison authored and pytorchmergebot committed
Add Context Manager for Disabling Multithreading in Backwards, use in aot autograd (pytorch#86245)
We were running into a few issues with multithreaded backwards in aot_autograd, such as pytorch#86136, and with `FakeTensorMode` getting into a weird state because functions were not executed completely sequentially. The multithreaded backwards is lost in translation when we trace out the backwards anyway, and it adds a lot of additional complexity.

Pull Request resolved: pytorch#86245
Approved by: https://github.com/albanD, https://github.com/yf225
1 parent 237316a commit d048893

14 files changed (+126, −10)
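The new torch.autograd.set_multithreading_enabled API is a thread-local switch that forces the autograd engine to run the whole backward pass on the calling thread. A minimal usage sketch (the tensor below is an arbitrary placeholder, not from the PR):

import torch

x = torch.randn(4, requires_grad=True)

# Inside the block, every backward node is routed to the calling
# thread's CPU ready queue instead of a per-device worker thread.
with torch.autograd.set_multithreading_enabled(False):
    x.sin().sum().backward()

print(x.grad)  # cos(x), computed on this thread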

aten/src/ATen/ThreadLocalState.cpp (+4)

@@ -27,6 +27,10 @@ void ThreadLocalState::set_grad_mode(bool enabled) {
   autograd_tls_.set_grad_mode(enabled);
 }
 
+void ThreadLocalState::set_multithreading_enabled(bool enabled) {
+  autograd_tls_.set_multithreading_enabled(enabled);
+}
+
 /* static */
 void ThreadLocalState::setThreadLocalState(
     const ThreadLocalState& state) {

aten/src/ATen/ThreadLocalState.h (+6)

@@ -30,6 +30,12 @@ class TORCH_API ThreadLocalState {
   // autograd engine.
   void set_grad_mode(bool enabled);
 
+  // set_multithreading_enabled - force the value of the
+  // multithreading-enabled TLS in the current state object.
+  // This is used for example in the
+  // autograd engine.
+  void set_multithreading_enabled(bool enabled);
+
   // Sets thread local variables in the current thread,
   // according to the thread boundary specified
   static void setThreadLocalState(const ThreadLocalState& state);

c10/core/AutogradState.cpp (+4, −2)

@@ -3,11 +3,13 @@
 namespace c10 {
 
 namespace {
-// By default, grad mode is enabled and inference mode is disabled
+// By default, grad mode and multithreading are enabled; inference mode is
+// disabled.
 thread_local AutogradState autograd_state_tls = AutogradState(
     /* grad_mode */ true,
     /* inference_mode */ false,
-    /* fw_grad_mode */ true);
+    /* fw_grad_mode */ true,
+    /* multithreading_enabled */ true);
 } // namespace
 
 AutogradState& AutogradState::get_tls_state() {

c10/core/AutogradState.h (+16, −2)

@@ -12,10 +12,15 @@ struct C10_API AutogradState {
   static AutogradState& get_tls_state();
   static void set_tls_state(AutogradState state);
 
-  AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode)
+  AutogradState(
+      bool grad_mode,
+      bool inference_mode,
+      bool fw_grad_mode,
+      bool multithreading_enabled)
       : grad_mode_(grad_mode),
         inference_mode_(inference_mode),
-        fw_grad_mode_(fw_grad_mode) {}
+        fw_grad_mode_(fw_grad_mode),
+        multithreading_enabled_(multithreading_enabled) {}
 
   void set_grad_mode(bool enabled) {
     grad_mode_ = enabled;
@@ -29,6 +34,10 @@ struct C10_API AutogradState {
     inference_mode_ = enabled;
   }
 
+  void set_multithreading_enabled(bool multithreading_enabled) {
+    multithreading_enabled_ = multithreading_enabled;
+  }
+
   bool get_grad_mode() const {
     return grad_mode_;
   }
@@ -41,10 +50,15 @@ struct C10_API AutogradState {
     return inference_mode_;
   }
 
+  bool get_multithreading_enabled() const {
+    return multithreading_enabled_;
+  }
+
  private:
   bool grad_mode_ : 1;
   bool inference_mode_ : 1;
   bool fw_grad_mode_ : 1;
+  bool multithreading_enabled_ : 1;
 };
 
 } // namespace c10

c10/core/InferenceMode.h (+2, −1)

@@ -58,7 +58,8 @@ struct TORCH_API InferenceMode {
     AutogradState::set_tls_state(AutogradState(
         /* grad_mode */ !enabled,
         /* inference_mode */ enabled,
-        /* fw_grad_mode */ !enabled));
+        /* fw_grad_mode */ !enabled,
+        /* multithreading_enabled */ !enabled));
     DispatchKeySet included = enabled
         ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
         : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);

docs/source/torch.rst (+11)

@@ -630,3 +630,14 @@ Operator Tags
 .. This module needs to be documented. Adding here in the meantime
 .. for tracking purposes
 .. py:module:: torch.utils.model_dump
+
+.. automodule:: torch.autograd
+.. currentmodule:: torch.autograd
+
+Engine Configuration
+--------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    set_multithreading_enabled

functorch/_src/aot_autograd.py (+1, −1)

@@ -479,7 +479,7 @@ def create_aot_dispatcher_function(
     python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext()
     shape_env = ShapeEnv() if config.use_dynamic_shapes else None
 
-    with preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
+    with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode:
 
         def process_inputs(flat_args):
             if config.use_fake_tensor:
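A pattern worth noting in the surrounding hunk: optional modes are built as real context managers or nullcontext() placeholders, then stacked in a single with statement. A hedged sketch of the same pattern outside aot_autograd (use_dynamic_shapes below is a stand-in flag, not the functorch config object):

import torch
from contextlib import nullcontext

use_dynamic_shapes = False  # stand-in for config.use_dynamic_shapes
dispatcher_ctx = nullcontext()  # would be a real context manager if the flag were set

x = torch.randn(3, requires_grad=True)
with torch.autograd.set_multithreading_enabled(False), dispatcher_ctx:
    # backward now executes sequentially, so tracing-time TLS modes
    # (e.g. FakeTensorMode) are never touched from worker threads
    x.exp().sum().backward()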

test/allowlist_for_publicAPI.json (+1)

@@ -199,6 +199,7 @@
         "no_grad",
         "set_detect_anomaly",
         "set_grad_enabled",
+        "set_multithreading_enabled",
         "variable"
     ],
     "torch.autograd.function": [

test/test_autograd.py (+27)

@@ -9304,6 +9304,33 @@ def foo(x):
         with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
             foo(nt).backward(torch.nested.nested_tensor([torch.rand(1), torch.rand(1)], device=device))
 
+    @onlyCUDA
+    def test_backward_single_threaded(self):
+
+        threads_eq = None
+
+        class TestFn(Function):
+            @staticmethod
+            def forward(ctx, x, self):
+                ctx.self = self
+                ctx.tid = threading.get_ident()
+                return x.clone()
+
+            @staticmethod
+            def backward(ctx, gO):
+                nonlocal threads_eq
+                threads_eq = ctx.tid == threading.get_ident()
+                return gO, None
+
+        inp = torch.rand(10, device="cuda", requires_grad=True)
+
+        with torch.autograd.set_multithreading_enabled(False):
+            TestFn.apply(inp, None).sum().backward()
+        self.assertTrue(threads_eq)
+
+        TestFn.apply(inp, None).sum().backward()
+        self.assertFalse(threads_eq)
+
 # Import test cases from below autograd/ here. These are found
 # implicitly by the loader, so Flake8 thinks they are unused, hence
 # the suppressions.

torch/_C/__init__.pyi.in (+3)

@@ -948,6 +948,9 @@ class _DisableFuncTorch:
 class _EnableTorchFunction:
     def __init__(self) -> None: ...
 
+class _MultithreadingEnabled:
+    def __init__(self, mode: _bool) -> None: ...
+
 # Defined in torch/csrc/jit/python/script_init.cpp
 class LoggerBase(object):
     ...

torch/autograd/__init__.py (+1, −1)

@@ -15,7 +15,7 @@
 from .variable import Variable
 from .function import Function, NestedIOFunction
 from .gradcheck import gradcheck, gradgradcheck
-from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode
+from .grad_mode import no_grad, enable_grad, set_grad_enabled, inference_mode, set_multithreading_enabled
 from .anomaly_mode import detect_anomaly, set_detect_anomaly
 from ..overrides import has_torch_function, handle_torch_function, is_tensor_like
 from . import functional

torch/autograd/grad_mode.py (+34, −2)

@@ -5,7 +5,7 @@
 from typing import Any, Callable, TypeVar, cast
 
 __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
-           'inference_mode']
+           'inference_mode', 'set_multithreading_enabled']
 
 
 # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
@@ -184,7 +184,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 
 
 class set_grad_enabled(_DecoratorContextManager):
-    r"""Context-manager that sets gradient calculation to on or off.
+    r"""Context-manager that sets gradient calculation on or off.
 
     ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
     It can be used as a context-manager or as a function.
@@ -298,3 +298,35 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 
     def clone(self):
         return self.__class__(self.mode)
+
+
+class set_multithreading_enabled(_DecoratorContextManager):
+    r"""Context-manager that sets multithreaded backwards on or off.
+
+    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
+    It can be used as a context-manager or as a function.
+
+    This context manager is thread local; it will not affect computation
+    in other threads.
+
+    Args:
+        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
+                     (``False``).
+
+    .. note::
+        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
+
+    """
+
+    def __init__(self, mode: bool) -> None:
+        self.mode = mode
+        self.multithreading_enabled_guard = torch._C._MultithreadingEnabled(mode)
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(self, *args) -> None:
+        del self.multithreading_enabled_guard
+
+    def clone(self):
+        return self.__class__(self.mode)
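One design detail is visible above: the C++ guard is created in __init__ rather than __enter__, so the thread-local flag flips as soon as the object is constructed, and __exit__ restores it by dropping the guard. A sketch of the function-style use the docstring mentions, under that reading (the reference must be kept alive, or the guard's destructor restores the old value immediately):

import torch

ctx = torch.autograd.set_multithreading_enabled(False)  # flag flips here, at construction
try:
    torch.randn(3, requires_grad=True).sum().backward()  # single-threaded backward
finally:
    ctx.__exit__()  # drops the guard; the previous value is restored

The with form is the safer spelling, since it pairs construction and restoration automatically.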

torch/csrc/autograd/engine.cpp (+3, −1)

@@ -1255,7 +1255,9 @@ void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
 auto Engine::ready_queue(
     std::shared_ptr<ReadyQueue> cpu_ready_queue,
     at::Device device) -> std::shared_ptr<ReadyQueue> {
-  if (should_run_in_cpu_ready_queue(device.type())) {
+  bool multithreading_disabled =
+      !c10::AutogradState::get_tls_state().get_multithreading_enabled();
+  if (multithreading_disabled || should_run_in_cpu_ready_queue(device.type())) {
     // return the cpu ready queue passed in
     TORCH_INTERNAL_ASSERT(cpu_ready_queue);
     return cpu_ready_queue;

torch/csrc/autograd/init.cpp (+13)

@@ -43,6 +43,17 @@ struct DisableFuncTorch {
   c10::impl::ExcludeDispatchKeyGuard back_guard_;
 };
 
+struct MultithreadingEnabled {
+  MultithreadingEnabled(bool enabled)
+      : old_(c10::AutogradState::get_tls_state().get_multithreading_enabled()) {
+    c10::AutogradState::get_tls_state().set_multithreading_enabled(enabled);
+  }
+  ~MultithreadingEnabled() {
+    c10::AutogradState::get_tls_state().set_multithreading_enabled(old_);
+  }
+  bool old_;
+};
+
 struct EnableTorchFunction {
   EnableTorchFunction()
       : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) {
@@ -354,6 +365,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
           _C_m, "_DisablePythonDispatcher")
       .def(py::init<>());
   py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());
+  py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled")
+      .def(py::init<bool>());
 
   py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
       .def(py::init([]() -> torch::autograd::SavedVariable {
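The pybind11 class above is a plain RAII guard: the constructor saves the current thread-local value and installs the new one, and the destructor restores the saved value. A hedged sketch of those semantics through the private binding (in CPython, del releases the last reference and runs the C++ destructor immediately):

import torch

outer = torch._C._MultithreadingEnabled(False)  # saves the default (True), sets False
inner = torch._C._MultithreadingEnabled(True)   # saves False, sets True
del inner  # destructor restores False
del outer  # destructor restores True

Nested guards unwind correctly as long as they are destroyed in reverse order of creation, which is exactly what the Python context manager arranges.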
