7
7
#include < ATen/core/ScalarType.h>
8
8
#include < ATen/core/TensorOptions.h>
9
9
10
+ // NB: This file is compiled even in CPU build (for some reason), so
11
+ // make sure you don't include any CUDA only headers.
12
+
10
13
using namespace at ;
11
14
15
+ // TODO: This might be generally helpful aliases elsewhere.
16
+ at::Device CPUDevice (DeviceIndex index) {
17
+ return at::Device (at::kCPU );
18
+ }
19
+ at::Device CUDADevice (DeviceIndex index) {
20
+ return at::Device (at::kCUDA , index );
21
+ }
22
+
12
23
// A macro so we don't lose location information when an assertion fails.
13
24
#define REQUIRE_OPTIONS (device_, index_, type_, layout_ ) \
14
25
ASSERT_EQ (options.device().type(), Device((device_), (index_)).type()); \
@@ -54,14 +65,14 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) {
54
65
if (at::globalContext ().getNumGPUs () > 1 ) {
55
66
Tensor tensor;
56
67
{
57
- DeviceGuard guard (1 );
68
+ DeviceGuard guard (CUDADevice ( 1 ) );
58
69
tensor = empty (5 , device (kCUDA ));
59
70
}
60
71
options = tensor.options ();
61
72
REQUIRE_OPTIONS (kCUDA , 1 , kFloat , kStrided );
62
73
63
74
{
64
- DeviceGuard guard (1 );
75
+ DeviceGuard guard (CUDADevice ( 1 ) );
65
76
tensor = empty (5 , device (kCUDA ).layout (kSparse ));
66
77
}
67
78
options = tensor.options ();
@@ -94,15 +105,15 @@ TEST(OptionsGuardTest, DeviceGuardOptionsGuardInteraction_MultiCUDA) {
94
105
Tensor tensor;
95
106
{
96
107
// Check that OptionsGuard respects any active device before construction.
97
- DeviceGuard guard (1 );
108
+ DeviceGuard guard (CUDADevice ( 1 ) );
98
109
{
99
110
OptionsGuard guard (device (kCUDA ));
100
111
tensor = at::empty ({10 });
101
112
REQUIRE_TENSOR_OPTIONS (kCUDA , 1 , kFloat , kStrided );
102
113
{
103
114
// Check that OptionsGuard respects any active device after
104
115
// construction.
105
- DeviceGuard guard (0 );
116
+ DeviceGuard guard (CUDADevice ( 0 ) );
106
117
tensor = at::empty ({10 });
107
118
REQUIRE_TENSOR_OPTIONS (kCUDA , 0 , kFloat , kStrided );
108
119
{
@@ -116,7 +127,7 @@ TEST(OptionsGuardTest, DeviceGuardOptionsGuardInteraction_MultiCUDA) {
116
127
}
117
128
118
129
TEST (DeviceGuardTest, IsMovable_CUDA) {
119
- DeviceGuard first (1 );
130
+ DeviceGuard first (CUDADevice ( 1 ) );
120
131
ASSERT_EQ (first.original_index (), 0 );
121
132
ASSERT_EQ (first.last_index (), 1 );
122
133
DeviceGuard second (std::move (first));
0 commit comments