10
10
#include < cstddef>
11
11
12
12
namespace at {
13
- // / RAII guard that sets a certain default GPU index in its constructor, and
13
+ // / RAII guard that sets a certain default device in its constructor, and
14
14
// / changes it back to the device that was originally active upon destruction.
15
15
// /
16
- // / The index is always reset to the one that was active at the time of
17
- // / construction of the guard. Even if you `set_index ` after construction, the
18
- // / destructor will still reset the index to the one that was active at
16
+ // / The device is always reset to the one that was active at the time of
17
+ // / construction of the guard. Even if you `set_device ` after construction, the
18
+ // / destructor will still reset the device to the one that was active at
19
19
// / construction time.
20
20
struct DeviceGuard {
21
21
// / Default constructor, does nothing.
22
22
DeviceGuard () = default ;
23
23
24
- // / Uses the given device's `index()` if it is a CUDA device, else does
25
- // / nothing.
24
+ // / Set the current device to the passed Device.
26
25
explicit DeviceGuard (Device device) {
27
- if (device.is_cuda ()) {
28
- set_index (device.index ());
29
- }
26
+ set_device (device);
30
27
}
31
28
32
29
explicit DeviceGuard (c10::optional<Device> device_opt) {
33
- if (device_opt.has_value () && device_opt. value (). is_cuda () ) {
34
- set_index (device_opt.value (). index ());
30
+ if (device_opt.has_value ()) {
31
+ set_device (device_opt.value ());
35
32
}
36
33
}
37
34
38
- // / Sets the device to the index on which the given tensor is located.
35
+ // / Sets the current device to the device on which the given tensor is located.
39
36
explicit DeviceGuard (const Tensor& tensor) {
40
- set_index_from (tensor);
37
+ set_device_from (tensor);
41
38
}
42
39
43
- // / Sets the device to the index on which the first tensor in the list is
40
+ // / Sets the current device to the device on which the first tensor in the list is
44
41
// / located. If the list is empty, does nothing.
45
42
explicit DeviceGuard (const TensorList& tensors) {
46
43
if (!tensors.empty ()) {
47
- set_index_from (tensors.front ());
44
+ set_device_from (tensors.front ());
48
45
}
49
46
}
50
47
@@ -71,7 +68,7 @@ struct DeviceGuard {
71
68
return *this ;
72
69
}
73
70
74
- // / Resets the device to the index that was active at construction of the
71
+ // / Resets the device to the device that was active at construction of the
75
72
// / guard.
76
73
~DeviceGuard () {
77
74
// It should only not have a value if an index was never actually set.
@@ -82,7 +79,12 @@ struct DeviceGuard {
82
79
}
83
80
84
81
// / Sets the device to the given one.
85
- void set_index (int16_t index) {
82
+ void set_device (at::Device device) {
83
+ if (device.type () == at::kCPU ) {
84
+ return ;
85
+ }
86
+ AT_ASSERT (device.type () == at::kCUDA );
87
+ auto index = device.index ();
86
88
if (index == -1 ) {
87
89
return ;
88
90
}
@@ -100,28 +102,35 @@ struct DeviceGuard {
100
102
last_index_ = index ;
101
103
}
102
104
103
- // / Calls `set_index ` with the `Tensor`'s current device, if it is a CUDA
104
- // / tensor. Does nothing if the `tensor` is not defined.
105
- void set_index_from (const Tensor& tensor) {
106
- if (tensor.defined () && tensor. is_cuda () ) {
107
- set_index (tensor.get_device ());
105
+ // / Calls `set_device ` with the `Tensor`'s current device, if it is not a
106
+ // / CPU tensor. Does nothing if the `tensor` is not defined.
107
+ void set_device_from (const Tensor& tensor) {
108
+ if (tensor.defined ()) {
109
+ set_device (tensor.device ());
108
110
}
109
111
}
110
112
111
113
// / Returns the device that was set upon construction of the guard.
112
- int16_t original_index () const noexcept {
113
- return original_index_;
114
+ at::Device original_device () const noexcept {
115
+ return original_index_ == - 1 ? at:: kCPU : at::Device (at:: kCUDA , original_index_) ;
114
116
}
115
117
116
- // / Returns the last device that was set via `set_index `, if any.
117
- int16_t last_index () const noexcept {
118
- return last_index_;
118
+ // / Returns the last device that was set via `set_device `, if any.
119
+ at::Device last_device () const noexcept {
120
+ return last_index_ == - 1 ? at:: kCPU : at::Device (at:: kCUDA , last_index_) ;
119
121
}
120
122
121
123
private:
124
+ // This representation only works under the assumption that the DeviceType
125
+ // is only CUDA. I think a reasonable invariant to assert for DeviceGuard
126
+ // is that once you've "picked" a device type, you can't mix set_device
127
+ // with other device types.
128
+
122
129
// / The original device that was active at construction of this object.
130
+ // / If not -1, it is a CUDA device.
123
131
int16_t original_index_ = -1 ;
124
- // / The last index that was set via `set_index`.
132
+ // / The last device that was set via `set_device`. If not -1, it is a CUDA
133
+ // / device.
125
134
int16_t last_index_ = -1 ;
126
135
};
127
136
} // namespace at
0 commit comments