Skip to content

Commit f571ae4

Browse files
Revert "Make torch.device usable as a context manager (pytorch#91525)"
This reverts commit 619d52a. Reverted pytorch#91525 on behalf of https://github.com/mehtanirav due to Internal breakages
1 parent c73147f commit f571ae4

File tree

12 files changed

+90
-401
lines changed

12 files changed

+90
-401
lines changed

docs/source/tensor_attributes.rst

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,6 @@ Via a string and device ordinal:
177177
>>> torch.device('cpu', 0)
178178
device(type='cpu', index=0)
179179

180-
The device object can also be used as a context manager to change the default
181-
device tensors are allocated on:
182-
183-
::
184-
185-
>>> with torch.device('cuda:1'):
186-
... r = torch.randn(2, 3)
187-
>>> r.device
188-
device(type='cuda', index=1)
189-
190-
This context manager has no effect if a factory function is passed an explicit,
191-
non-None device argument. To globally change the default device, see also
192-
:func:`torch.set_default_device`.
193-
194180
.. note::
195181
The :class:`torch.device` argument in functions can generally be substituted with a string.
196182
This allows for fast prototyping of code.

docs/source/torch.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ Tensors
1717
is_nonzero
1818
set_default_dtype
1919
get_default_dtype
20-
set_default_device
2120
set_default_tensor_type
2221
numel
2322
set_printoptions

test/test_overrides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,7 @@ def __getattr__(self, name):
760760
val = getattr(self._data, name)
761761

762762
# If it's a method
763-
if not isinstance(val, torch.device) and callable(val):
763+
if callable(val):
764764
c = getattr(type(self._data), name)
765765
# Don't append self to args if classmethod/staticmethod
766766
if c is val:

test/test_utils.py

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,8 @@
1414
import torch.nn as nn
1515
import torch.utils.data
1616
from torch.utils.data import DataLoader
17-
from torch.testing._internal.common_device_type import (
18-
ops,
19-
onlyCPU,
20-
instantiate_device_type_tests,
21-
)
22-
from torch.testing._internal.common_methods_invocations import op_db
2317
import torch.cuda
24-
from torch.utils._pytree import tree_any, tree_all_only
2518
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
26-
from torch import set_default_device
27-
from torch.utils._device import set_device
2819
import torch.utils.cpp_extension
2920
from torch.autograd._functions.utils import check_onnx_broadcast
3021
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
@@ -805,74 +796,6 @@ def test_external_module_register(self):
805796
torch._register_device_module('xpu', DummyXPUModule)
806797

807798

808-
class TestDeviceUtils(TestCase):
809-
def test_basic(self):
810-
with torch.device('meta') as dev:
811-
x = torch.empty(3, 3)
812-
self.assertEqual(x.device.type, 'meta')
813-
self.assertEqual(dev, torch.device('meta'))
814-
815-
def test_decorator(self):
816-
@set_device('meta')
817-
def f():
818-
return torch.empty(3, 3)
819-
self.assertEqual(f().device.type, 'meta')
820-
821-
def test_decorator_generator(self):
822-
@set_device('meta')
823-
def f():
824-
yield torch.empty(3, 3)
825-
yield torch.empty(3, 3)
826-
r1, r2 = list(f())
827-
self.assertEqual(r1.device.type, 'meta')
828-
self.assertEqual(r2.device.type, 'meta')
829-
830-
831-
def test_nn_module(self):
832-
with torch.device('meta'):
833-
m = nn.Linear(40, 50)
834-
self.assertEqual(m.weight.device.type, 'meta')
835-
836-
def test_set_default_device(self):
837-
try:
838-
set_default_device('meta')
839-
r = torch.empty(2, 2)
840-
finally:
841-
set_default_device(None)
842-
843-
self.assertEqual(r.device.type, 'meta')
844-
845-
@onlyCPU
846-
@ops(op_db)
847-
def test_device_mode_ops(self, device, dtype, op):
848-
func = op.get_op()
849-
samples = op.sample_inputs(device, dtype, requires_grad=False)
850-
for sample in samples:
851-
# Only test samples which don't have Tensor inputs. However,
852-
# we don't test the factory property on OpInfo as it is very,
853-
# very incomplete
854-
if tree_any(
855-
lambda x: isinstance(x, torch.Tensor),
856-
(sample.input, sample.args, sample.kwargs)
857-
):
858-
continue
859-
# Many OpInfos will explicitly pass in a device. DeviceContext
860-
# will respect device if it is explicitly specified. To test
861-
# DeviceContext, we have to remove the device kwarg in this case.
862-
# NB: Can't pass None to sample_inputs, the function can't
863-
# handle it.
864-
kwargs = sample.kwargs.copy()
865-
kwargs.pop('device', None)
866-
with torch.device('meta'):
867-
r = func(sample.input, *sample.args, **kwargs)
868-
self.assertTrue(
869-
tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
870-
)
871-
872-
873-
instantiate_device_type_tests(TestDeviceUtils, globals())
874-
875-
876799
class TestCppExtensionUtils(TestCase):
877800
def test_cpp_compiler_is_ok(self):
878801
self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform('c++'))

torch/_C/__init__.pyi.in

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,6 @@ class device:
4545
@overload
4646
def __init__(self, type: str, index: _int) -> None: ...
4747

48-
def __call__(self, func: T) -> T: ...
49-
50-
def __enter__(self) -> "device": ...
51-
52-
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
53-
5448
def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce
5549

5650
# Defined in torch/csrc/Stream.cpp

torch/__init__.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535

3636
__all__ = [
3737
'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
38-
'set_default_device',
3938
'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
4039
'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
4140
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
@@ -445,49 +444,6 @@ def is_storage(obj):
445444
return type(obj) in _storage_classes
446445

447446

448-
_GLOBAL_DEVICE_CONTEXT = None
449-
450-
def set_default_device(device):
451-
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
452-
does not affect factory function calls which are called with an explicit
453-
``device`` argument. Factory calls will be performed as if they
454-
were passed ``device`` as an argument.
455-
456-
To only temporarily change the default device instead of setting it
457-
globally, use ``with torch.device(device):`` instead.
458-
459-
The default device is initially ``cpu``. If you set the default tensor
460-
device to another device (e.g., ``cuda``) without a device index, tensors
461-
will be allocated on whatever the current device for the device type,
462-
even after :func:`torch.cuda.set_device` is called.
463-
464-
Args:
465-
device (device or string): the device to set as default
466-
467-
Example::
468-
469-
>>> # xdoctest: +SKIP("requires cuda, changes global state")
470-
>>> torch.tensor([1.2, 3]).device
471-
device(type='cpu')
472-
>>> torch.set_default_device('cuda') # current device is 0
473-
>>> torch.tensor([1.2, 3]).device
474-
device(type='cuda', index=0)
475-
>>> torch.set_default_device('cuda:1')
476-
>>> torch.tensor([1.2, 3]).device
477-
device(type='cuda', index=1)
478-
479-
"""
480-
global _GLOBAL_DEVICE_CONTEXT
481-
if _GLOBAL_DEVICE_CONTEXT is not None:
482-
_GLOBAL_DEVICE_CONTEXT.__exit__(None, None, None)
483-
if device is None:
484-
_GLOBAL_DEVICE_CONTEXT = None
485-
return
486-
from torch.utils._device import DeviceContext
487-
_GLOBAL_DEVICE_CONTEXT = DeviceContext(device)
488-
_GLOBAL_DEVICE_CONTEXT.__enter__()
489-
490-
491447
def set_default_tensor_type(t):
492448
r"""Sets the default ``torch.Tensor`` type to floating point tensor type
493449
``t``. This type will also be used as default floating point type for

torch/autograd/grad_mode.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,96 @@
1+
import sys
12
import torch
2-
from typing import Any
3-
4-
from torch.utils._contextlib import _DecoratorContextManager
3+
import functools
4+
import inspect
5+
import warnings
6+
from typing import Any, Callable, TypeVar, cast
57

68
__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
79
'inference_mode', 'set_multithreading_enabled']
810

11+
12+
# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
13+
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
14+
FuncType = Callable[..., Any]
15+
F = TypeVar('F', bound=FuncType)
16+
17+
18+
class _DecoratorContextManager:
19+
"""Allow a context manager to be used as a decorator"""
20+
21+
def __call__(self, func: F) -> F:
22+
if inspect.isclass(func):
23+
warnings.warn("Decorating classes is deprecated and will be disabled in "
24+
"future versions. You should only decorate functions or methods. "
25+
"To preserve the current behavior of class decoration, you can "
26+
"directly decorate the `__init__` method and nothing else.")
27+
28+
if inspect.isgeneratorfunction(func):
29+
return self._wrap_generator(func)
30+
31+
@functools.wraps(func)
32+
def decorate_context(*args, **kwargs):
33+
with self.clone():
34+
return func(*args, **kwargs)
35+
return cast(F, decorate_context)
36+
37+
def _wrap_generator(self, func):
38+
"""Wrap each generator invocation with the context manager"""
39+
@functools.wraps(func)
40+
def generator_context(*args, **kwargs):
41+
gen = func(*args, **kwargs)
42+
43+
# Generators are suspended and unsuspended at `yield`, hence we
44+
# make sure the grad mode is properly set every time the execution
45+
# flow returns into the wrapped generator and restored when it
46+
# returns through our `yield` to our caller (see PR #49017).
47+
try:
48+
# Issuing `None` to a generator fires it up
49+
with self.clone():
50+
response = gen.send(None)
51+
52+
while True:
53+
try:
54+
# Forward the response to our caller and get its next request
55+
request = yield response
56+
57+
except GeneratorExit:
58+
# Inform the still active generator about its imminent closure
59+
with self.clone():
60+
gen.close()
61+
raise
62+
63+
except BaseException:
64+
# Propagate the exception thrown at us by the caller
65+
with self.clone():
66+
response = gen.throw(*sys.exc_info())
67+
68+
else:
69+
# Pass the last request to the generator and get its response
70+
with self.clone():
71+
response = gen.send(request)
72+
73+
# We let the exceptions raised above by the generator's `.throw` or
74+
# `.send` methods bubble up to our caller, except for StopIteration
75+
except StopIteration as e:
76+
# The generator informed us that it is done: take whatever its
77+
# returned value (if any) was and indicate that we're done too
78+
# by returning it (see docs for python's return-statement).
79+
return e.value
80+
81+
return generator_context
82+
83+
def __enter__(self) -> None:
84+
raise NotImplementedError
85+
86+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
87+
raise NotImplementedError
88+
89+
def clone(self):
90+
# override this method if your children class takes __init__ parameters
91+
return self.__class__()
92+
93+
994
class no_grad(_DecoratorContextManager):
1095
r"""Context-manager that disabled gradient calculation.
1196

torch/csrc/Device.cpp

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -169,36 +169,6 @@ PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
169169
END_HANDLE_TH_ERRORS
170170
}
171171

172-
PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
173-
HANDLE_TH_ERRORS
174-
py::object mode = py::module::import("torch.utils._device")
175-
.attr("DeviceContext")(py::handle(self));
176-
at::impl::PythonTorchFunctionTLS::push_onto_stack(
177-
std::make_shared<c10::SafePyObject>(
178-
mode.release().ptr(), getPyInterpreter()));
179-
// So that with torch.device('cuda') as dev: works
180-
Py_INCREF(self);
181-
return self;
182-
END_HANDLE_TH_ERRORS
183-
}
184-
185-
PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
186-
HANDLE_TH_ERRORS
187-
at::impl::PythonTorchFunctionTLS::pop_stack();
188-
Py_RETURN_NONE;
189-
END_HANDLE_TH_ERRORS
190-
}
191-
192-
PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
193-
HANDLE_TH_ERRORS
194-
py::object deco =
195-
py::module::import("torch.utils._device").attr("device_decorator");
196-
return deco(py::handle(self), *py::handle(args), **py::handle(kwargs))
197-
.release()
198-
.ptr();
199-
END_HANDLE_TH_ERRORS
200-
}
201-
202172
typedef PyObject* (*getter)(PyObject*, void*);
203173

204174
// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
@@ -212,8 +182,6 @@ static struct PyGetSetDef THPDevice_properties[] = {
212182
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
213183
static PyMethodDef THPDevice_methods[] = {
214184
{"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
215-
{"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
216-
{"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
217185
{nullptr} /* Sentinel */
218186
};
219187

@@ -231,11 +199,6 @@ PyTypeObject THPDeviceType = {
231199
nullptr, /* tp_as_sequence */
232200
nullptr, /* tp_as_mapping */
233201
(hashfunc)THPDevice_hash, /* tp_hash */
234-
// TODO: We're not sure if this is a good idea or not, because making
235-
// torch.device callable means that it will start returning true
236-
// for callable() queries, and that is unexpected. We can always add
237-
// this later, so for now, don't actually implement this
238-
// THPDevice_call, /* tp_call */
239202
nullptr, /* tp_call */
240203
(reprfunc)THPDevice_str, /* tp_str */
241204
nullptr, /* tp_getattro */

torch/overrides.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def get_ignored_functions() -> Set[Callable]:
7575
torch.is_tensor,
7676
torch.is_storage,
7777
torch.set_default_tensor_type,
78-
torch.set_default_device,
7978
torch.set_rng_state,
8079
torch.get_rng_state,
8180
torch.manual_seed,

torch/signal/windows/windows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def kaiser(
378378
device=device,
379379
requires_grad=requires_grad)
380380

381-
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta, device=device))
381+
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(torch.tensor(beta))
382382

383383

384384
@_add_docstr(

0 commit comments

Comments
 (0)