 import torch.nn as nn
 import torch.utils.data
 from torch.utils.data import DataLoader
+from torch.testing._internal.common_device_type import (
+    ops,
+    onlyCPU,
+    instantiate_device_type_tests,
+)
+from torch.testing._internal.common_methods_invocations import op_db
 import torch.cuda
+from torch.utils._pytree import tree_any, tree_all_only
 from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+from torch import set_default_device
+from torch.utils._device import set_device
 import torch.utils.cpp_extension
 from torch.autograd._functions.utils import check_onnx_broadcast
 from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
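The new imports pull in the device-utility entry points under test (`set_default_device` from `torch`, the `set_device` decorator from `torch.utils._device`), the pytree helpers, and the OpInfo/device-type machinery (`ops`, `onlyCPU`, `instantiate_device_type_tests`, `op_db`) used by the parametrized test further down. As a minimal sketch of the behavior the added tests rely on, assuming a PyTorch build where `torch.device` acts as a context manager and `torch.set_default_device` is available:

```python
import torch
import torch.nn as nn

# Inside the context manager, factory functions allocate on the given device;
# 'meta' tensors carry only shape/dtype metadata and no real storage.
with torch.device('meta'):
    layer = nn.Linear(40, 50)
print(layer.weight.device)  # meta

# set_default_device changes the implicit device process-wide;
# passing None restores the default (factory calls go back to CPU).
torch.set_default_device('meta')
try:
    t = torch.empty(2, 2)
finally:
    torch.set_default_device(None)
print(t.device)  # meta
```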
@@ -796,6 +805,74 @@ def test_external_module_register(self):
         torch._register_device_module('xpu', DummyXPUModule)


+class TestDeviceUtils(TestCase):
+    def test_basic(self):
+        with torch.device('meta') as dev:
+            x = torch.empty(3, 3)
+        self.assertEqual(x.device.type, 'meta')
+        self.assertEqual(dev, torch.device('meta'))
+
+    def test_decorator(self):
+        @set_device('meta')
+        def f():
+            return torch.empty(3, 3)
+        self.assertEqual(f().device.type, 'meta')
+
+    def test_decorator_generator(self):
+        @set_device('meta')
+        def f():
+            yield torch.empty(3, 3)
+            yield torch.empty(3, 3)
+        r1, r2 = list(f())
+        self.assertEqual(r1.device.type, 'meta')
+        self.assertEqual(r2.device.type, 'meta')
+
+    def test_nn_module(self):
+        with torch.device('meta'):
+            m = nn.Linear(40, 50)
+        self.assertEqual(m.weight.device.type, 'meta')
+
+    def test_set_default_device(self):
+        try:
+            set_default_device('meta')
+            r = torch.empty(2, 2)
+        finally:
+            set_default_device(None)
+
+        self.assertEqual(r.device.type, 'meta')
+
+    @onlyCPU
+    @ops(op_db)
+    def test_device_mode_ops(self, device, dtype, op):
+        func = op.get_op()
+        samples = op.sample_inputs(device, dtype, requires_grad=False)
+        for sample in samples:
+            # Only test samples which don't have Tensor inputs.  However,
+            # we don't test the factory property on OpInfo as it is very,
+            # very incomplete
+            if tree_any(
+                lambda x: isinstance(x, torch.Tensor),
+                (sample.input, sample.args, sample.kwargs)
+            ):
+                continue
+            # Many OpInfos will explicitly pass in a device.  DeviceContext
+            # will respect device if it is explicitly specified.  To test
+            # DeviceContext, we have to remove the device kwarg in this case.
+            # NB: Can't pass None to sample_inputs, the function can't
+            # handle it.
+            kwargs = sample.kwargs.copy()
+            kwargs.pop('device', None)
+            with torch.device('meta'):
+                r = func(sample.input, *sample.args, **kwargs)
+            self.assertTrue(
+                tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', r)
+            )
+
+
+instantiate_device_type_tests(TestDeviceUtils, globals())
+
+
 class TestCppExtensionUtils(TestCase):
     def test_cpp_compiler_is_ok(self):
         self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform('c++'))
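The `instantiate_device_type_tests(TestDeviceUtils, globals())` call expands the generic class into per-device test classes (e.g. `TestDeviceUtilsCPU`), which is how the `device`, `dtype`, and `op` arguments reach `test_device_mode_ops`; with `@onlyCPU`, that test only runs on the CPU variant. The pytree helpers drive both the sample filter and the output check. A standalone sketch of that pattern outside the OpInfo harness (the `args`, `kwargs`, and `torch.ones` call here are made up for illustration):

```python
import torch
from torch.utils._pytree import tree_any, tree_all_only

# Only factory-style calls (no Tensor inputs) are expected to pick up the
# ambient device, so skip anything that already contains a Tensor.
args = (3, 3)
kwargs = {'dtype': torch.float32}
if not tree_any(lambda x: isinstance(x, torch.Tensor), (args, kwargs)):
    with torch.device('meta'):
        out = torch.ones(*args, **kwargs)
    # Every Tensor leaf of the (possibly nested) result must land on 'meta'.
    assert tree_all_only(torch.Tensor, lambda x: x.device.type == 'meta', out)
```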