
Commit 81b5eff

ezyang authored and pytorchmergebot committed
Reland "Add torch.utils.device_mode" (pytorch#91796)
Original PR pytorch#91525

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: pytorch#91796
Approved by: https://github.com/albanD
1 parent eeb3e49 commit 81b5eff

13 files changed: +422 -94 lines

docs/source/tensor_attributes.rst (+14)
@@ -177,6 +177,20 @@ Via a string and device ordinal:
     >>> torch.device('cpu', 0)
     device(type='cpu', index=0)
 
+The device object can also be used as a context manager to change the default
+device tensors are allocated on:
+
+::
+
+    >>> with torch.device('cuda:1'):
+    ...     r = torch.randn(2, 3)
+    >>> r.device
+    device(type='cuda', index=1)
+
+This context manager has no effect if a factory function is passed an explicit,
+non-None device argument. To globally change the default device, see also
+:func:`torch.set_default_device`.
+
 .. note::
     The :class:`torch.device` argument in functions can generally be substituted with a string.
     This allows for fast prototyping of code.
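To make the documented behavior concrete, here is a minimal sketch (an editor's illustration, not part of the diff; it uses the 'meta' device so it runs without a GPU): the context manager only applies to factory calls that omit the device argument or pass device=None.

    import torch

    with torch.device('meta'):
        a = torch.randn(2, 3)                # takes the context device
        b = torch.randn(2, 3, device='cpu')  # explicit, non-None device wins
        c = torch.randn(2, 3, device=None)   # None behaves as if omitted

    print(a.device.type)  # meta
    print(b.device.type)  # cpu
    print(c.device.type)  # meta
    print(torch.randn(2, 3).device.type)  # cpu: default restored outside the block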

docs/source/torch.rst (+1)
@@ -17,6 +17,7 @@ Tensors
     is_nonzero
     set_default_dtype
     get_default_dtype
+    set_default_device
     set_default_tensor_type
     numel
     set_printoptions

test/test_autograd.py (+18 -4)
@@ -122,22 +122,36 @@ def test_grad_mode_class_decoration(self):
         with self.assertWarnsRegex(UserWarning, "Decorating classes is deprecated"):
             @torch.no_grad()
             class Foo():
-                pass
+                def __init__(self):
+                    assert not torch.is_grad_enabled()
+
+                def foo(self):
+                    # Not applied to methods
+                    assert torch.is_grad_enabled()
+
+            # Show that we can actually construct the class
+            foo = Foo()
+            foo.foo()
 
         # Decorating functions or methods is fine though
         with warnings.catch_warnings(record=True) as w:
             @torch.no_grad()
             def foo():
-                pass
+                assert not torch.is_grad_enabled()
+
+            foo()
 
             class Foo2():
                 @torch.no_grad()
                 def __init__(self):
-                    pass
+                    assert not torch.is_grad_enabled()
 
                 @torch.no_grad()
                 def foo(self):
-                    pass
+                    assert not torch.is_grad_enabled()
+
+            foo2 = Foo2()
+            foo2.foo()
 
             self.assertEqual(len(w), 0)
 

test/test_overrides.py (+1 -1)
@@ -760,7 +760,7 @@ def __getattr__(self, name):
         val = getattr(self._data, name)
 
         # If it's a method
-        if callable(val):
+        if not isinstance(val, torch.device) and callable(val):
             c = getattr(type(self._data), name)
             # Don't append self to args if classmethod/staticmethod
             if c is val:

test/test_utils.py (+77)
@@ -14,8 +14,17 @@
 import torch.nn as nn
 import torch.utils.data
 from torch.utils.data import DataLoader
+from torch.testing._internal.common_device_type import (
+    ops,
+    onlyCPU,
+    instantiate_device_type_tests,
+)
+from torch.testing._internal.common_methods_invocations import op_db
 import torch.cuda
+from torch.utils._pytree import tree_any, tree_all_only
 from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+from torch import set_default_device
+from torch.utils._device import set_device
 import torch.utils.cpp_extension
 from torch.autograd._functions.utils import check_onnx_broadcast
 from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
@@ -796,6 +805,74 @@ def test_external_module_register(self):
         torch._register_device_module('xpu', DummyXPUModule)
 
 
+class TestDeviceUtils(TestCase):
+    def test_basic(self):
+        with torch.device('meta') as dev:
+            x = torch.empty(3, 3)
+        self.assertEqual(x.device.type, 'meta')
+        self.assertEqual(dev, torch.device('meta'))
+
+    def test_decorator(self):
+        @set_device('meta')
+        def f():
+            return torch.empty(3, 3)
+        self.assertEqual(f().device.type, 'meta')
+
+    def test_decorator_generator(self):
+        @set_device('meta')
+        def f():
+            yield torch.empty(3, 3)
+            yield torch.empty(3, 3)
+        r1, r2 = list(f())
+        self.assertEqual(r1.device.type, 'meta')
+        self.assertEqual(r2.device.type, 'meta')
+
+
+    def test_nn_module(self):
+        with torch.device('meta'):
+            m = nn.Linear(40, 50)
+        self.assertEqual(m.weight.device.type, 'meta')
+
+    def test_set_default_device(self):
+        try:
+            set_default_device('meta')
+            r = torch.empty(2, 2)
+        finally:
+            set_default_device(None)
+
+        self.assertEqual(r.device.type, 'meta')
+
+    @onlyCPU
+    @ops(op_db)
+    def test_device_mode_ops(self, device, dtype, op):
+        func = op.get_op()
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+        for sample in samples:
+            # Only test samples which don't have Tensor inputs. However,
+            # we don't test the factory property on OpInfo as it is very,
+            # very incomplete
+            if tree_any(
+                lambda x: isinstance(x, torch.Tensor),
+                (sample.input, sample.args, sample.kwargs)
+            ):
+                continue
+            # Many OpInfos will explicitly pass in a device. DeviceContext
+            # will respect device if it is explicitly specified. To test
+            # DeviceContext, we have to remove the device kwarg in this case.
+            # NB: Can't pass None to sample_inputs, the function can't
+            # handle it.
+            kwargs = sample.kwargs.copy()
+            kwargs.pop('device', None)
+            with torch.device('meta'):
+                r = func(sample.input, *sample.args, **kwargs)
+            self.assertTrue(
+                tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
+            )
+
+
+instantiate_device_type_tests(TestDeviceUtils, globals())
+
+
 class TestCppExtensionUtils(TestCase):
     def test_cpp_compiler_is_ok(self):
         self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform('c++'))
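The final assertion in test_device_mode_ops leans on the pytree helpers imported above. As a rough illustration of what they compute (the example values are made up; the helper signatures are the ones used in the hunk):

    import torch
    from torch.utils._pytree import tree_any, tree_all_only

    out = {'x': torch.empty(2, device='meta'), 'scale': 1.0}

    # tree_any: does any leaf of the nested structure satisfy the predicate?
    print(tree_any(lambda x: isinstance(x, torch.Tensor), out))  # True

    # tree_all_only: do all leaves of the given type satisfy the predicate?
    # Leaves of other types (the float above) are ignored.
    print(tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', out))  # True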

torch/_C/__init__.pyi.in (+7)
@@ -45,6 +45,13 @@ class device:
     @overload
     def __init__(self, type: str, index: _int) -> None: ...
 
+    # Uncomment if we ever make torch.device a decorator
+    # def __call__(self, func: T) -> T: ...
+
+    def __enter__(self) -> "device": ...
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
+
     def __reduce__(self) -> Tuple[Any, ...]: ...  # THPDevice_reduce
 
 # Defined in torch/csrc/Stream.cpp

torch/__init__.py (+44)
@@ -35,6 +35,7 @@
 
 __all__ = [
     'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
+    'set_default_device',
     'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
     'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
     'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
@@ -444,6 +445,49 @@ def is_storage(obj):
     return type(obj) in _storage_classes
 
 
+_GLOBAL_DEVICE_CONTEXT = None
+
+def set_default_device(device):
+    """Sets the default ``torch.Tensor`` to be allocated on ``device``. This
+    does not affect factory function calls which are called with an explicit
+    ``device`` argument. Factory calls will be performed as if they
+    were passed ``device`` as an argument.
+
+    To only temporarily change the default device instead of setting it
+    globally, use ``with torch.device(device):`` instead.
+
+    The default device is initially ``cpu``. If you set the default tensor
+    device to another device (e.g., ``cuda``) without a device index, tensors
+    will be allocated on whatever the current device for that device type is,
+    even after :func:`torch.cuda.set_device` is called.
+
+    Args:
+        device (device or string): the device to set as default
+
+    Example::
+
+        >>> # xdoctest: +SKIP("requires cuda, changes global state")
+        >>> torch.tensor([1.2, 3]).device
+        device(type='cpu')
+        >>> torch.set_default_device('cuda')  # current device is 0
+        >>> torch.tensor([1.2, 3]).device
+        device(type='cuda', index=0)
+        >>> torch.set_default_device('cuda:1')
+        >>> torch.tensor([1.2, 3]).device
+        device(type='cuda', index=1)
+
+    """
+    global _GLOBAL_DEVICE_CONTEXT
+    if _GLOBAL_DEVICE_CONTEXT is not None:
+        _GLOBAL_DEVICE_CONTEXT.__exit__(None, None, None)
+    if device is None:
+        _GLOBAL_DEVICE_CONTEXT = None
+        return
+    from torch.utils._device import DeviceContext
+    _GLOBAL_DEVICE_CONTEXT = DeviceContext(device)
+    _GLOBAL_DEVICE_CONTEXT.__enter__()
+
+
 def set_default_tensor_type(t):
     r"""Sets the default ``torch.Tensor`` type to floating point tensor type
     ``t``. This type will also be used as default floating point type for
torch/autograd/grad_mode.py (+3 -88)
@@ -1,96 +1,11 @@
-import sys
 import torch
-import functools
-import inspect
-import warnings
-from typing import Any, Callable, TypeVar, cast
+from typing import Any
+
+from torch.utils._contextlib import _DecoratorContextManager
 
 __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
            'inference_mode', 'set_multithreading_enabled']
 
-
-# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
-# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
-FuncType = Callable[..., Any]
-F = TypeVar('F', bound=FuncType)
-
-
-class _DecoratorContextManager:
-    """Allow a context manager to be used as a decorator"""
-
-    def __call__(self, func: F) -> F:
-        if inspect.isclass(func):
-            warnings.warn("Decorating classes is deprecated and will be disabled in "
-                          "future versions. You should only decorate functions or methods. "
-                          "To preserve the current behavior of class decoration, you can "
-                          "directly decorate the `__init__` method and nothing else.")
-
-        if inspect.isgeneratorfunction(func):
-            return self._wrap_generator(func)
-
-        @functools.wraps(func)
-        def decorate_context(*args, **kwargs):
-            with self.clone():
-                return func(*args, **kwargs)
-        return cast(F, decorate_context)
-
-    def _wrap_generator(self, func):
-        """Wrap each generator invocation with the context manager"""
-        @functools.wraps(func)
-        def generator_context(*args, **kwargs):
-            gen = func(*args, **kwargs)
-
-            # Generators are suspended and unsuspended at `yield`, hence we
-            # make sure the grad mode is properly set every time the execution
-            # flow returns into the wrapped generator and restored when it
-            # returns through our `yield` to our caller (see PR #49017).
-            try:
-                # Issuing `None` to a generator fires it up
-                with self.clone():
-                    response = gen.send(None)
-
-                while True:
-                    try:
-                        # Forward the response to our caller and get its next request
-                        request = yield response
-
-                    except GeneratorExit:
-                        # Inform the still active generator about its imminent closure
-                        with self.clone():
-                            gen.close()
-                        raise
-
-                    except BaseException:
-                        # Propagate the exception thrown at us by the caller
-                        with self.clone():
-                            response = gen.throw(*sys.exc_info())
-
-                    else:
-                        # Pass the last request to the generator and get its response
-                        with self.clone():
-                            response = gen.send(request)
-
-            # We let the exceptions raised above by the generator's `.throw` or
-            # `.send` methods bubble up to our caller, except for StopIteration
-            except StopIteration as e:
-                # The generator informed us that it is done: take whatever its
-                # returned value (if any) was and indicate that we're done too
-                # by returning it (see docs for python's return-statement).
-                return e.value
-
-        return generator_context
-
-    def __enter__(self) -> None:
-        raise NotImplementedError
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        raise NotImplementedError
-
-    def clone(self):
-        # override this method if your children class takes __init__ parameters
-        return self.__class__()
-
-
 class no_grad(_DecoratorContextManager):
     r"""Context-manager that disabled gradient calculation.
9611
