
Commit 0aaff5e

ezyang authored and facebook-github-bot committed
Replace CUDA-specific set_index(_from) method from DeviceGuard with set_device. (pytorch#13275)
Summary:
Pull Request resolved: pytorch#13275

This resulted in a bunch of knock-on changes, which I will now describe:

- s/original_index/original_device/
- s/last_index/last_device/
- A bunch of places that used set_index now use CUDAGuard (which does have set_index), because they were CUDA-specific code.

Major caveat: DeviceGuard doesn't *actually* work for non-CUDA/CPU devices. To make that happen, I plan on totally replacing the implementation of DeviceGuard; what I mostly care about here is wrangling the API into an acceptable state.

Reviewed By: gchanan

Differential Revision: D12832080

fbshipit-source-id: 7de068c7cec35663dc8a533026a626331336e61d
1 parent e5d5665 commit 0aaff5e

19 files changed (+115 -97 lines)
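
For callers migrating across this change, a minimal before/after sketch (illustrative only, not code from this commit; `some_index` is a placeholder parameter):

#include <ATen/DeviceGuard.h>

void migrate_example(int16_t some_index) {
  // Before this commit, DeviceGuard exposed a CUDA-specific index API:
  //   at::DeviceGuard guard;
  //   guard.set_index(some_index);            // bare index, implicitly CUDA
  //   int16_t prev = guard.original_index();
  //
  // After this commit, DeviceGuard takes a full at::Device; CUDA-only call
  // sites switch to at::cuda::CUDAGuard, which still has an index-based set_index.
  at::DeviceGuard guard;
  guard.set_device(at::Device(at::kCUDA, some_index));
  at::Device prev = guard.original_device();   // the device that was current before the switch
  (void)prev;
}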

aten/src/ATen/DeviceGuard.h

+37 -28

@@ -10,41 +10,38 @@
 #include <cstddef>
 
 namespace at {
-/// RAII guard that sets a certain default GPU index in its constructor, and
+/// RAII guard that sets a certain default device in its constructor, and
 /// changes it back to the device that was originally active upon destruction.
 ///
-/// The index is always reset to the one that was active at the time of
-/// construction of the guard. Even if you `set_index` after construction, the
-/// destructor will still reset the index to the one that was active at
+/// The device is always reset to the one that was active at the time of
+/// construction of the guard. Even if you `set_device` after construction, the
+/// destructor will still reset the device to the one that was active at
 /// construction time.
 struct DeviceGuard {
   /// Default constructor, does nothing.
   DeviceGuard() = default;
 
-  /// Uses the given device's `index()` if it is a CUDA device, else does
-  /// nothing.
+  /// Set the current device to the passed Device.
   explicit DeviceGuard(Device device) {
-    if (device.is_cuda()) {
-      set_index(device.index());
-    }
+    set_device(device);
   }
 
   explicit DeviceGuard(c10::optional<Device> device_opt) {
-    if (device_opt.has_value() && device_opt.value().is_cuda()) {
-      set_index(device_opt.value().index());
+    if (device_opt.has_value()) {
+      set_device(device_opt.value());
     }
   }
 
-  /// Sets the device to the index on which the given tensor is located.
+  /// Sets the current device to the device on which the given tensor is located.
   explicit DeviceGuard(const Tensor& tensor) {
-    set_index_from(tensor);
+    set_device_from(tensor);
   }
 
-  /// Sets the device to the index on which the first tensor in the list is
+  /// Sets the current device to the device on which the first tensor in the list is
   /// located. If the list is empty, does nothing.
   explicit DeviceGuard(const TensorList& tensors) {
     if (!tensors.empty()) {
-      set_index_from(tensors.front());
+      set_device_from(tensors.front());
     }
   }
 
@@ -71,7 +68,7 @@ struct DeviceGuard {
     return *this;
   }
 
-  /// Resets the device to the index that was active at construction of the
+  /// Resets the device to the device that was active at construction of the
   /// guard.
   ~DeviceGuard() {
     // It should only not have a value if an index was never actually set.
@@ -82,7 +79,12 @@
   }
 
   /// Sets the device to the given one.
-  void set_index(int16_t index) {
+  void set_device(at::Device device) {
+    if (device.type() == at::kCPU) {
+      return;
+    }
+    AT_ASSERT(device.type() == at::kCUDA);
+    auto index = device.index();
     if (index == -1) {
       return;
     }
@@ -100,28 +102,35 @@
     last_index_ = index;
   }
 
-  /// Calls `set_index` with the `Tensor`'s current device, if it is a CUDA
-  /// tensor. Does nothing if the `tensor` is not defined.
-  void set_index_from(const Tensor& tensor) {
-    if (tensor.defined() && tensor.is_cuda()) {
-      set_index(tensor.get_device());
+  /// Calls `set_device` with the `Tensor`'s current device, if it is not a
+  /// CPU tensor. Does nothing if the `tensor` is not defined.
+  void set_device_from(const Tensor& tensor) {
+    if (tensor.defined()) {
+      set_device(tensor.device());
     }
   }
 
   /// Returns the device that was set upon construction of the guard.
-  int16_t original_index() const noexcept {
-    return original_index_;
+  at::Device original_device() const noexcept {
+    return original_index_ == -1 ? at::kCPU : at::Device(at::kCUDA, original_index_);
   }
 
-  /// Returns the last device that was set via `set_index`, if any.
-  int16_t last_index() const noexcept {
-    return last_index_;
+  /// Returns the last device that was set via `set_device`, if any.
+  at::Device last_device() const noexcept {
+    return last_index_ == -1 ? at::kCPU : at::Device(at::kCUDA, last_index_);
   }
 
  private:
+  // This representation only works under the assumption that the DeviceType
+  // is only CUDA. I think a reasonable invariant to assert for DeviceGuard
+  // is that once you've "picked" a device type, you can't mix set_device
+  // with other device types.
+
   /// The original device that was active at construction of this object.
+  /// If not -1, it is a CUDA device.
   int16_t original_index_ = -1;
-  /// The last index that was set via `set_index`.
+  /// The last device that was set via `set_device`. If not -1, it is a CUDA
+  /// device.
   int16_t last_index_ = -1;
 };
 } // namespace at
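
A short usage sketch of the semantics documented above (not from the diff; assumes a machine with at least two CUDA devices, and that device 0 is current on entry):

#include <ATen/DeviceGuard.h>

void device_guard_example() {
  {
    at::DeviceGuard guard(at::Device(at::kCUDA, 1));  // switches to CUDA device 1
    // guard.original_device() == at::Device(at::kCUDA, 0)
    // guard.last_device()     == at::Device(at::kCUDA, 1)
  }  // destructor restores the device that was active at construction
  at::DeviceGuard cpu_guard(at::Device(at::kCPU));    // CPU device: set_device is a no-op
}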

aten/src/ATen/cuda/CUDAGuard.h

+8 -8

@@ -76,13 +76,13 @@ struct CUDAGuard {
 
   /// Sets the CUDA device to the given one.
   /// TODO: Deprecate this name
-  void set_device(int32_t device) {
-    device_guard_.set_index(device);
+  void set_device(int32_t device_index) {
+    set_index(device_index);
   }
 
   /// Sets the CUDA device to the given one.
-  void set_index(int32_t device) {
-    device_guard_.set_index(device);
+  void set_index(int32_t device_index) {
+    device_guard_.set_device(at::Device(at::kCUDA, device_index));
   }
 
   /// Returns the CUDA streams that were active in the first call to
@@ -93,13 +93,13 @@
   }
 
   /// Returns the device that was set upon construction of the guard.
-  int32_t original_device() const noexcept {
-    return device_guard_.original_index();
+  Device original_device() const noexcept {
+    return device_guard_.original_device();
   }
 
   /// Returns the last device that was set via `set_device`, if any.
-  int32_t last_device() const noexcept {
-    return device_guard_.last_index();
+  Device last_device() const noexcept {
+    return device_guard_.last_device();
   }
 
  private:
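
For CUDA-specific call sites (the nccl and profiler changes below), a sketch of how the index-based CUDAGuard API now funnels into DeviceGuard (device index 1 is just an example):

#include <ATen/cuda/CUDAGuard.h>

void cuda_guard_example() {
  at::cuda::CUDAGuard guard;
  guard.set_index(1);  // internally calls DeviceGuard::set_device(at::Device(at::kCUDA, 1))
  at::Device original = guard.original_device();  // now an at::Device, not an int32_t
  at::Device last = guard.last_device();          // at::Device(at::kCUDA, 1)
  (void)original;
  (void)last;
}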

aten/src/ATen/native/TensorConversions.cpp

+1 -0

@@ -10,6 +10,7 @@ namespace native {
 // comparing against the current device object in Tensor.
 // This always **copies** but this is intended because (1) we shouldn't modify
 // input argument, and (2) Device is small anyways.
+// NB: This ONLY works for CUDA device
 static inline Device ensure_has_index(const Device &device) {
   if (!device.is_cuda() || device.has_index()) {
     return device;

aten/src/ATen/templates/TypeDefault.cpp

+1 -4

@@ -28,10 +28,7 @@ Tensor & TypeDefault::copy_(Tensor & self, const Tensor & src, bool non_blocking
 }
 
 Tensor TypeDefault::copy(const Tensor & src, bool non_blocking, optional<Device> to_device) const {
-  DeviceGuard device_guard;
-  if (to_device.has_value()) {
-    device_guard.set_index(to_device.value().index());
-  }
+  DeviceGuard device_guard(to_device);
   AT_CHECK(src.defined(), "attempt to copy an undefined tensor");
   Tensor r;
   if (is_sparse()) {
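
The one-line replacement above works because the new optional<Device> constructor folds the has_value check in itself; a sketch of the equivalent caller behavior, assuming the DeviceGuard header shown earlier:

// nullopt leaves the current device untouched, a CPU device is a no-op,
// and a CUDA device switches (and is restored when the guard dies).
c10::optional<at::Device> to_device;  // may or may not hold a value
at::DeviceGuard device_guard(to_device);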

aten/src/ATen/test/stream_test.cpp

+6 -6

@@ -152,7 +152,7 @@ TEST(TestStream, CUDAGuardTest) {
   // Setting a stream changes the current device and the stream on that device
   {
     at::cuda::CUDAGuard guard(streams1[1]);
-    ASSERT_EQ_CUDA(guard.last_device(), 1);
+    ASSERT_EQ_CUDA(guard.last_device(), at::Device(at::kCUDA, 1));
     ASSERT_EQ_CUDA(at::cuda::current_device(), 1);
     ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(1), streams1[1]);
   }
@@ -164,7 +164,7 @@ TEST(TestStream, CUDAGuardTest) {
   // Setting only the device changes only the current device and not the stream
   {
     at::cuda::CUDAGuard guard(/*device=*/1);
-    ASSERT_EQ_CUDA(guard.last_device(), 1);
+    ASSERT_EQ_CUDA(guard.last_device(), at::Device(at::kCUDA, 1));
     ASSERT_EQ_CUDA(at::cuda::current_device(), 1);
     ASSERT_EQ_CUDA(at::cuda::getCurrentCUDAStream(1), streams1[0]);
   }
@@ -196,13 +196,13 @@ TEST(TestStream, CUDAGuardMovableTest) {
   first.set_device(1);
   at::cuda::CUDAGuard second(std::move(first));
   ASSERT_EQ_CUDA(second.original_streams().size(), device_count);
-  ASSERT_EQ_CUDA(second.original_device(), 0);
-  ASSERT_EQ_CUDA(second.last_device(), 1);
+  ASSERT_EQ_CUDA(second.original_device(), at::Device(at::kCUDA, 0));
+  ASSERT_EQ_CUDA(second.last_device(), at::Device(at::kCUDA, 1));
   at::cuda::CUDAGuard third;
   third = std::move(second);
   ASSERT_EQ_CUDA(third.original_streams().size(), device_count);
-  ASSERT_EQ_CUDA(third.original_device(), 0);
-  ASSERT_EQ_CUDA(third.last_device(), 1);
+  ASSERT_EQ_CUDA(third.original_device(), at::Device(at::kCUDA, 0));
+  ASSERT_EQ_CUDA(third.last_device(), at::Device(at::kCUDA, 1));
 }
 
 // Streampool Round Robin

test/cpp/api/tensor_options_cuda.cpp

+9 -9

@@ -13,7 +13,7 @@
 using namespace at;
 
 // TODO: This might be generally helpful aliases elsewhere.
-at::Device CPUDevice(DeviceIndex index) {
+at::Device CPUDevice() {
   return at::Device(at::kCPU);
 }
 at::Device CUDADevice(DeviceIndex index) {
@@ -128,15 +128,15 @@ TEST(OptionsGuardTest, DeviceGuardOptionsGuardInteraction_MultiCUDA) {
 
 TEST(DeviceGuardTest, IsMovable_CUDA) {
   DeviceGuard first(CUDADevice(1));
-  ASSERT_EQ(first.original_index(), 0);
-  ASSERT_EQ(first.last_index(), 1);
+  ASSERT_EQ(first.original_device(), CUDADevice(0));
+  ASSERT_EQ(first.last_device(), CUDADevice(1));
   DeviceGuard second(std::move(first));
-  ASSERT_EQ(second.original_index(), 0);
-  ASSERT_EQ(second.last_index(), 1);
-  ASSERT_EQ(first.original_index(), -1);
+  ASSERT_EQ(second.original_device(), CUDADevice(0));
+  ASSERT_EQ(second.last_device(), CUDADevice(1));
+  ASSERT_EQ(first.original_device(), CPUDevice());
   DeviceGuard third;
   third = std::move(second);
-  ASSERT_EQ(third.original_index(), 0);
-  ASSERT_EQ(third.last_index(), 1);
-  ASSERT_EQ(second.original_index(), -1);
+  ASSERT_EQ(third.original_device(), CUDADevice(0));
+  ASSERT_EQ(third.last_device(), CUDADevice(1));
+  ASSERT_EQ(second.original_device(), CPUDevice());
 }

torch/csrc/autograd/profiler.cpp

+5 -1

@@ -1,6 +1,10 @@
 #include "torch/csrc/autograd/profiler.h"
 #include "torch/csrc/autograd/function.h"
 
+#ifdef USE_CUDA
+#include "ATen/cuda/CUDAGuard.h"
+#endif
+
 #include <sstream>
 
 namespace torch { namespace autograd { namespace profiler {
@@ -122,7 +126,7 @@ RecordFunction::RecordFunction(const char* name, int64_t current_sequence_nr)
 
 #ifdef USE_CUDA
 static void onEachDevice(std::function<void(int)> op) {
-  at::DeviceGuard device_guard;
+  at::cuda::CUDAGuard device_guard;
   int count;
   TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
   for(int i = 0; i < count; i++) {

torch/csrc/autograd/python_function.cpp

+2 -4

@@ -45,15 +45,13 @@ namespace torch { namespace autograd {
 
 VariableInfo::VariableInfo(const Variable& var)
   : type(&var.type())
+  , device(var.device())
   , size(var.sizes().vec())
   , requires_grad(var.requires_grad()) {
-  if (var.type().is_cuda()) {
-    device = var.get_device();
-  }
 }
 
 Variable VariableInfo::zeros(at::DeviceGuard& device_guard) const {
-  device_guard.set_index(device);
+  device_guard.set_device(device);
   return at::zeros(size, type->options());
 }

torch/csrc/autograd/python_function.h

+1 -1

@@ -27,7 +27,7 @@ struct VariableInfo {
   Variable zeros(at::DeviceGuard& device_guard) const;
 
   at::Type* type;
-  int32_t device = -1;
+  at::Device device = at::kCPU;
   std::vector<int64_t> size;
   bool requires_grad;
 };

torch/csrc/cuda/nccl.cpp

+3 -2

@@ -4,6 +4,7 @@
 #include "torch/csrc/utils/hash.h"
 
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAGuard.h>
 #include <c10/util/Exception.h>
 
 #include <THC/THC.h>
@@ -241,7 +242,7 @@ void broadcast(
   const auto comms = user_comms.empty() ? _get_communicators(tensors)
                                         : ArrayRef<ncclComm_t>(user_comms);
 
-  at::DeviceGuard device_guard;
+  at::cuda::CUDAGuard device_guard;
   AutoNcclGroup nccl_group_guard;
   for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
     int device = tensors[i].get_device();
@@ -288,7 +289,7 @@ void reduce(
   auto comms_ref = user_comms.empty() ? _get_communicators(inputs)
                                       : ArrayRef<ncclComm_t>(user_comms);
 
-  at::DeviceGuard device_guard;
+  at::cuda::CUDAGuard device_guard;
   AutoNcclGroup nccl_group_guard;
   for (size_t i = 0; i < len; i++) {
     int device = inputs[i].device().index();

torch/csrc/cuda/python_nccl.cpp

+5 -3

@@ -9,6 +9,8 @@
 #include "torch/csrc/cuda/nccl.h"
 #include "torch/csrc/utils/functional.h"
 
+#include <ATen/cuda/CUDAGuard.h>
+
 #include <nccl.h>
 
 #include <sstream>
@@ -192,7 +194,7 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
   std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
   auto comms = user_comms.empty() ? _get_communicators(inputs)
                                   : ArrayRef<ncclComm_t>(user_comms);
-  at::DeviceGuard device_guard;
+  at::cuda::CUDAGuard device_guard;
   AutoNcclGroup nccl_group_guard;
   for (size_t i = 0; i < len; i++) {
     int device = inputs[i].get_device();
@@ -272,7 +274,7 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
   std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
   auto comms = user_comms.empty() ? _get_communicators(inputs)
                                   : ArrayRef<ncclComm_t>(user_comms);
-  at::DeviceGuard device_guard;
+  at::cuda::CUDAGuard device_guard;
   AutoNcclGroup nccl_group_guard;
   for (size_t i = 0; i < len; i++) {
     int device = inputs[i].get_device();
@@ -335,7 +337,7 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
   std::lock_guard<std::mutex> lock(*(THCCachingAllocator_getCudaFreeMutex()));
   auto comms = user_comms.empty() ? _get_communicators(inputs)
                                   : ArrayRef<ncclComm_t>(user_comms);
-  at::DeviceGuard device_guard;
+  at::cuda::CUDAGuard device_guard;
   AutoNcclGroup nccl_group_guard;
   for (size_t i = 0; i < len; i++) {
     int device = inputs[i].get_device();

torch/csrc/distributed/c10d/ddp.cpp

+2 -1

@@ -183,7 +183,8 @@ void syncReduction(
 
   // Now make the BW stream wait on it
   auto bwDevice = cudaGuard.original_device();
-  auto bwStream = cudaGuard.original_streams()[bwDevice];
+  AT_ASSERT(bwDevice.type() == at::kCUDA);
+  auto bwStream = cudaGuard.original_streams()[bwDevice.index()];
 
   // Now let the BW stream wait for the worker stream
   event.block(bwStream);

torch/csrc/utils/tensor_new.cpp

+4 -1

@@ -94,7 +94,10 @@ Tensor new_with_tensor_copy(const Type& type, Tensor other, int32_t device_index
   AutoNoGIL no_gil;
   at::DeviceGuard device_guard;
   if (type.is_cuda()) {
-    device_guard.set_index(device_index);
+    // TODO: It would be better if new_with_tensor_copy took an at::Device
+    // to begin with, but then we need to fix the situation with
+    // dispatch_type_conversion bleggg
+    device_guard.set_device(at::Device(at::kCUDA, device_index));
   }
   return type.copy(other);
 }
