|
5 | 5 | from typing import Any, Callable, TypeVar, cast
|
6 | 6 |
|
7 | 7 | __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
|
8 |
| - 'inference_mode'] |
| 8 | + 'inference_mode', 'set_multithreading_enabled'] |
9 | 9 |
|
10 | 10 |
|
11 | 11 | # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
|
@@ -184,7 +184,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
184 | 184 |
|
185 | 185 |
|
186 | 186 | class set_grad_enabled(_DecoratorContextManager):
|
187 |
| - r"""Context-manager that sets gradient calculation to on or off. |
| 187 | + r"""Context-manager that sets gradient calculation on or off. |
188 | 188 |
|
189 | 189 | ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
|
190 | 190 | It can be used as a context-manager or as a function.
|
@@ -298,3 +298,35 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
298 | 298 |
|
299 | 299 | def clone(self):
|
300 | 300 | return self.__class__(self.mode)
|
| 301 | + |
| 302 | + |
| 303 | +class set_multithreading_enabled(_DecoratorContextManager): |
| 304 | + r"""Context-manager that sets multithreaded backwards on or off. |
| 305 | +
|
| 306 | + ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`. |
| 307 | + It can be used as a context-manager or as a function. |
| 308 | +
|
| 309 | + This context manager is thread local; it will not affect computation |
| 310 | + in other threads. |
| 311 | +
|
| 312 | + Args: |
| 313 | + mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable |
| 314 | + (``False``). |
| 315 | +
|
| 316 | + .. note:: |
| 317 | + This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`. |
| 318 | +
|
| 319 | + """ |
| 320 | + |
| 321 | + def __init__(self, mode: bool) -> None: |
| 322 | + self.mode = mode |
| 323 | + self.multithreadeding_enabled_guard = torch._C._MultithreadingEnabled(mode) |
| 324 | + |
| 325 | + def __enter__(self) -> None: |
| 326 | + pass |
| 327 | + |
| 328 | + def __exit__(self, *args) -> None: |
| 329 | + del self.multithreadeding_enabled_guard |
| 330 | + |
| 331 | + def clone(self): |
| 332 | + return self.__class__(self.mode) |
0 commit comments