
Commit d8dab6f

SsnL authored and facebook-github-bot committed
Add tensor.to(options) (pytorch#13146)
Summary:
ezyang on the template hack
smessmer on SFINAE of the `TensorOptions(Device)` constructor
goldsborough on the C++ API test changes
zdevito on the `jit` codegen changes

Pull Request resolved: pytorch#13146

Reviewed By: ezyang
Differential Revision: D12823809
Pulled By: SsnL
fbshipit-source-id: 98d65c401c98fda1c6fa358e4538f86c6495abdc
1 parent 3365d74 commit d8dab6f

15 files changed: +405 −101 lines

aten/src/ATen/core/Tensor.h

Lines changed: 13 additions & 1 deletion
@@ -650,9 +650,9 @@ class CAFFE2_API Tensor {
   std::vector<Tensor> unbind(int64_t dim=0) const;
   Tensor to_sparse(int64_t sparse_dim) const;
   Tensor to_sparse() const;
+  Tensor to(const TensorOptions & options, bool non_blocking=false, bool copy=false) const;
   Tensor to(Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) const;
   Tensor to(ScalarType dtype, bool non_blocking=false, bool copy=false) const;
-  Tensor to(Device device, bool non_blocking=false, bool copy=false) const;
   Tensor to(const Tensor & other, bool non_blocking=false, bool copy=false) const;
   Scalar _local_scalar() const;
   int64_t storage_offset() const;
@@ -774,6 +774,18 @@ class CAFFE2_API Tensor {
   Tensor remainder(Scalar other) const;
   Tensor remainder(const Tensor & other) const;
 
+  // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want
+  // at::kDouble and its friends to be TypeMetas, but that hasn't happened yet.
+  // Until it does, these methods maintain backward compatibility for C++
+  // usage like `x.to(y.dtype())`.
+  // TODO: remove the following two once at::kDouble and its friends are TypeMetas.
+  inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+  inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+
   template <typename F, typename... Args>
   auto m(F func, Args&&... params) const -> decltype(func(*this, std::forward<Args>(params)...)) {
     return func(*this, std::forward<Args>(params)...);
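
For context, a minimal sketch of what these BC overloads let existing C++ code keep doing (the tensors and function name here are illustrative, and the device line assumes a CUDA-enabled build):

#include <ATen/ATen.h>

void dtype_bc_sketch() {
  at::Tensor x = at::ones({2, 2}, at::kFloat);
  at::Tensor y = at::zeros({2, 2}, at::kDouble);
  // .dtype() now returns caffe2::TypeMeta; the overload above keeps this working:
  at::Tensor a = x.to(y.dtype());
  // Device + TypeMeta variant (uncomment on a CUDA build):
  // at::Tensor b = x.to(at::Device(at::kCUDA, 0), y.dtype());
}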

aten/src/ATen/core/TensorMethods.h

Lines changed: 3 additions & 3 deletions
@@ -1226,15 +1226,15 @@ inline Tensor Tensor::to_sparse(int64_t sparse_dim) const {
 inline Tensor Tensor::to_sparse() const {
   return type().to_sparse(*this);
 }
+inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool copy) const {
+  return type().to(*this, options, non_blocking, copy);
+}
 inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, bool copy) const {
   return type().to(*this, device, dtype, non_blocking, copy);
 }
 inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const {
   return type().to(*this, dtype, non_blocking, copy);
 }
-inline Tensor Tensor::to(Device device, bool non_blocking, bool copy) const {
-  return type().to(*this, device, non_blocking, copy);
-}
 inline Tensor Tensor::to(const Tensor & other, bool non_blocking, bool copy) const {
   return type().to(*this, other, non_blocking, copy);
 }
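
These inline wrappers are the first hop in the dispatch chain; roughly (a sketch of the flow, not the generated code):

// t.to(options, non_blocking, copy)        // inline method (TensorMethods.h)
//   -> t.type().to(t, options, ...)        // virtual dispatch via Type (Type.h)
//     -> at::native::to(t, options, ...)   // concrete kernel (TensorConversions.cpp)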

aten/src/ATen/core/TensorOptions.h

Lines changed: 64 additions & 7 deletions
@@ -8,6 +8,7 @@
 #include <ATen/core/ScalarTypeUtils.h>
 
 #include "c10/util/Optional.h"
+#include "c10/util/C++17.h"
 
 #include <cstddef>
 #include <iosfwd>
@@ -58,6 +59,51 @@ CAFFE2_API const DefaultTensorOptions& getDefaultTensorOptions();
 /// at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1
 /// at::zeros({2,2}, at::requires_grad());
 ///
+
+/// NOTE [ TensorOptions Constructors ]
+///
+/// TensorOptions is like a dictionary with entries from the set:
+/// {requires_grad, is_variable, device, dtype, layout}, where each entry may be
+/// unspecified (i.e., is optional). It is used to specify the properties of
+/// tensors in many places, both in the C++ internals and in the API, e.g., in
+/// tensor factory methods like `at::empty({10}, options)`, tensor conversions
+/// like `tensor.to(...)`, etc.
+///
+/// To provide a simple API that is consistent with Python, where one can do
+/// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a
+/// `torch.layout`, we want TensorOptions to be implicitly convertible from
+/// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have
+/// three implicit constructors, one from each of these three types.
+///
+/// This is sufficient for `ScalarType` and `Layout` as they are simple enum
+/// classes. However, `Device` is an ordinary class with implicit constructors
+/// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)`, to be
+/// consistent with the Python API, where strings are treated as equivalent to
+/// a `torch.device` object (e.g., "cuda:1" can be passed everywhere a
+/// `torch.device("cuda:1")` is accepted). To support the syntax
+/// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure
+/// that `TensorOptions` is implicitly constructible with any arguments that a
+/// `Device` can be constructed from. So we have,
+///
+///    /* implicit */ TensorOptions(T&& device) : TensorOptions() {
+///      this->set_device(device);
+///    }
+///
+///    template <typename... Args,
+///              typename = std::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
+///    /* implicit */ TensorOptions(Args&&... args)
+///      : TensorOptions(Device(std::forward<Args>(args)...)) {}
+///
+///
+/// But this will be problematic. Consider `TensorOptions({kCUDA, 1})`: the
+/// compiler will complain about ambiguity between the copy constructor and
+/// the `Device` constructor, because `{kCUDA, 1}` can be converted to both a
+/// `TensorOptions` and a `Device`.
+///
+/// To get around this, we templatize the `Device` constructor. Since overload
+/// resolution is done before template resolution, our problem is solved.
+
 struct CAFFE2_API TensorOptions {
   TensorOptions()
     : requires_grad_(false)
@@ -75,20 +121,31 @@ struct CAFFE2_API TensorOptions {
   }
 
   /// Constructs a `TensorOptions` object with the given device.
-  /* implicit */ TensorOptions(Device device) : TensorOptions() {
-    this->set_device(device);
+  /// See NOTE [ TensorOptions Constructors ] on why this is templatized.
+  template <typename T,
+            typename = c10::guts::enable_if_t<std::is_same<c10::guts::decay_t<T>, Device>::value>>
+  /* implicit */ TensorOptions(T&& device) : TensorOptions() {
+    this->set_device(std::forward<T>(device));
   }
 
+  /// Constructs a `TensorOptions` object from arguments allowed in `Device`
+  /// constructors.
+  ///
+  /// See NOTE [ TensorOptions Constructors ].
+  ///
+  /// NB: Ideally we would only allow implicit constructors here, but there
+  /// is no easy way to detect them, so this one also allows explicit
+  /// constructors.
+  template <typename... Args,
+            typename = c10::guts::enable_if_t<std::is_constructible<Device, Args&&...>::value>>
+  /* implicit */ TensorOptions(Args&&... args)
+    : TensorOptions(Device(std::forward<Args>(args)...)) {}
+
   /// Constructs a `TensorOptions` object from a backend, forwarded to the
   /// `Device` constructor.
   /* implicit */ TensorOptions(Backend backend)
     : TensorOptions(Device(backendToDeviceType(backend))) {}
 
-  /// Constructs a `TensorOptions` object from a device type, forwarded to the
-  /// `Device` constructor.
-  /* implicit */ TensorOptions(DeviceType device_type)
-    : TensorOptions(Device(device_type)) {}
-
   /// Constructs a `TensorOptions` object with the given dtype.
   /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {
     this->set_dtype(dtype);
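
The constructor scheme the NOTE arrives at can be exercised in isolation. Below is a self-contained sketch with simplified stand-ins: `Device`, `Options`, `kCUDA`, and `take` here are not the real ATen names, and `std::enable_if_t` stands in for the `c10::guts` helpers.

#include <type_traits>
#include <utility>

enum class DeviceType { CPU, CUDA };
constexpr DeviceType kCUDA = DeviceType::CUDA;

struct Device {
  /* implicit */ Device(DeviceType type, int index = -1)
      : type(type), index(index) {}
  DeviceType type;
  int index;
};

struct Options {
  Options() = default;

  // Templated "Device" constructor: because template argument deduction
  // fails for a braced-init-list, this constructor never competes with the
  // copy constructor for `{kCUDA, 1}`, avoiding the ambiguity in the NOTE.
  template <typename T,
            typename = std::enable_if_t<
                std::is_same<std::decay_t<T>, Device>::value>>
  /* implicit */ Options(T&& device) : device(std::forward<T>(device)) {}

  // Accept anything a Device can be constructed from.
  template <typename... Args,
            typename = std::enable_if_t<
                std::is_constructible<Device, Args&&...>::value>>
  /* implicit */ Options(Args&&... args)
      : Options(Device(std::forward<Args>(args)...)) {}

  Device device{DeviceType::CPU, -1};
};

void take(const Options&) {}

int main() {
  take(Options{kCUDA, 1});  // variadic constructor, like at::empty({10}, {kCUDA, 1})
  take({kCUDA, 1});         // braced list converts via the variadic constructor
  take(kCUDA);              // DeviceType -> Device -> Options, like tensor.to(kCUDA)
}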

aten/src/ATen/core/Type.h

Lines changed: 1 addition & 1 deletion
@@ -604,9 +604,9 @@ struct CAFFE2_API Type {
   virtual std::vector<Tensor> unbind(const Tensor & self, int64_t dim) const = 0;
   virtual Tensor to_sparse(const Tensor & self, int64_t sparse_dim) const = 0;
   virtual Tensor to_sparse(const Tensor & self) const = 0;
+  virtual Tensor to(const Tensor & self, const TensorOptions & options, bool non_blocking, bool copy) const = 0;
   virtual Tensor to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy) const = 0;
   virtual Tensor to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy) const = 0;
-  virtual Tensor to(const Tensor & self, Device device, bool non_blocking, bool copy) const = 0;
   virtual Tensor to(const Tensor & self, const Tensor & other, bool non_blocking, bool copy) const = 0;
   virtual Scalar _local_scalar(const Tensor & self) const = 0;
   virtual int64_t storage_offset(const Tensor & self) const = 0;

aten/src/ATen/core/typeid.h

Lines changed: 1 addition & 1 deletion
@@ -351,7 +351,7 @@ class CAFFE2_API TypeMeta {
  private:
   // TypeMeta can only be created by Make, making sure that we do not
   // create incorrectly mixed up TypeMeta objects.
-  constexpr TypeMeta(const detail::TypeMetaData* data) noexcept : data_(data) {}
+  explicit constexpr TypeMeta(const detail::TypeMetaData* data) noexcept : data_(data) {}
 
  public:
  /**
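
Making this constructor `explicit` means a raw `const detail::TypeMetaData*` can no longer silently become a `TypeMeta`. The commit does not spell out its motivation, but a plausible reading, in miniature (simplified stand-in types, not the real ones):

struct Data {};

struct Meta {
  // As in the diff above: construction must be spelled out.
  explicit constexpr Meta(const Data* d) noexcept : d_(d) {}
  const Data* d_;
};

void f(Meta) {}

void caller(const Data* p) {
  // f(p);     // no longer compiles: the pointer does not implicitly convert
  f(Meta(p));  // the conversion is now visible at the call site
}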

aten/src/ATen/function_wrapper.py

Lines changed: 5 additions & 1 deletion
@@ -527,6 +527,7 @@ def __getitem__(self, x):
     ('buffers', Optional[List[str]]),
     ('returns', List[ReturnType]),
     ('inplace', bool),
+    ('is_factory_method', bool),
     ('abstract', bool),
     ('requires_tensor', bool),
     ('device_guard', bool),
@@ -924,6 +925,7 @@ def process_option(option, output_options):
         buffers=buffer_names,
         returns=option['returns'],
         inplace=option['inplace'],
+        is_factory_method=False,
         # See Note [Abstract ATen methods]
         abstract=abstract,
         requires_tensor=option.get('requires_tensor', False),
@@ -1070,7 +1072,8 @@ def find_formal(formal_name, formals):
 
     is_method = 'method' in option['variants']
     is_namespace_function = 'function' in option['variants']
-    is_factory_method = find_formal('TensorOptions', formals) and not dispatch_options
+    is_factory_method = find_formal('TensorOptions', formals) and \
+        not dispatch_options and 'method' not in option['variants']
     is_deprecated_factory_method = len(formals) > 0 and \
         formals[0]['dynamic_type'] == 'Type' and \
         option['return_type'] == 'Tensor' and option['deprecated']
@@ -1171,6 +1174,7 @@ def find_formal(formal_name, formals):
         buffers=None,
         returns=option['returns'],
         inplace=option['inplace'],
+        is_factory_method=is_factory_method,
         # See Note [Abstract ATen methods]
         abstract=abstract,
         requires_tensor=option.get('requires_tensor', False),

aten/src/ATen/native/TensorConversions.cpp

Lines changed: 44 additions & 14 deletions
@@ -1,23 +1,59 @@
 #include "ATen/ATen.h"
 #include "ATen/NativeFunctions.h"
+#include "c10/util/Optional.h"
 
 namespace at {
 namespace native {
 
-static void ensure_has_index(Device* device) {
-  if (!device->is_cuda() || device->has_index()) {
-    return;
+// Since the given Device may not have device_index set (i.e., it is -1,
+// representing the current device), we need to set the device_index before
+// comparing against the current device object in Tensor.
+// This always **copies**, but that is intended because (1) we shouldn't
+// modify the input argument, and (2) Device is small anyway.
+static inline Device ensure_has_index(const Device &device) {
+  if (!device.is_cuda() || device.has_index()) {
+    return device;
   }
-  device->set_index(at::current_device());
+  return Device(device.type(), at::current_device());
 }
 
-static Tensor to_impl(const Tensor& self, const TensorOptions& options, bool non_blocking) {
+static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, bool non_blocking) {
   return self.type().toBackend(options.backend()).toScalarType(typeMetaToScalarType(options.dtype()))
                     .copy(self, non_blocking, options.device());
 }
 
+Tensor to(const Tensor& self, const TensorOptions& options, bool non_blocking, bool copy) {
+  AT_CHECK(options.requires_grad_opt() == c10::nullopt,
+           "to(options) expects unset requires_grad flag, but got "
+           "options.requires_grad set as ", options.requires_grad());
+
+  const auto & layout_opt = options.layout_opt();
+  AT_CHECK(!layout_opt || self.layout() == layout_opt.value(),
+           "to(options) doesn't support converting to a different layout, "
+           "but got self.layout being ", self.layout(),
+           " and options.layout set as ", options.layout());
+
+  auto device_opt = options.device_opt();
+  if (device_opt) {
+    device_opt = ensure_has_index(device_opt.value());
+  }
+  const auto & dtype_opt = options.dtype_opt();
+  if ((!device_opt || self.device() == device_opt.value()) &&
+      (!dtype_opt || self.dtype() == dtype_opt.value()) && !copy) {
+    return self;
+  }
+  auto specified_options = self.options();
+  if (device_opt) {
+    specified_options = specified_options.device(device_opt.value());
+  }
+  if (dtype_opt) {
+    specified_options = specified_options.dtype(dtype_opt.value());
+  }
+  return to_impl(self, specified_options, non_blocking);
+}
+
 Tensor to(const Tensor& self, Device device, ScalarType dtype, bool non_blocking, bool copy) {
-  ensure_has_index(&device);
+  device = ensure_has_index(device);
   if (self.device() == device && self.dtype() == dtype && !copy) {
     return self;
   }
@@ -31,17 +67,11 @@ Tensor to(const Tensor& self, ScalarType dtype, bool non_blocking, bool copy) {
   return to_impl(self, self.options().dtype(dtype), non_blocking);
 }
 
-Tensor to(const Tensor& self, Device device, bool non_blocking, bool copy) {
-  ensure_has_index(&device);
-  if (self.device() == device && !copy) {
-    return self;
-  }
-  return to_impl(self, self.options().device(device), non_blocking);
-}
-
 Tensor to(const Tensor& self, const Tensor& other, bool non_blocking, bool copy) {
   auto self_options = self.options();
   auto options = other.options();
+  // Tensor.options() always has everything filled in, so we don't even need
+  // to set the device index here.
   if (self_options == options && !copy) {
     return self;
   }
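
Put together, the new kernel gives `tensor.to(options)` the behavior sketched below (illustrative values; the CUDA line assumes a CUDA-enabled build):

#include <ATen/ATen.h>

void to_options_sketch() {
  at::Tensor t = at::ones({2, 2}, at::kFloat);

  // Nothing specified and copy == false: returns self without copying.
  at::Tensor same = t.to(at::TensorOptions());

  // dtype-only conversion; layout must match and requires_grad must be
  // unset in the options, or the AT_CHECKs above fire.
  at::Tensor d = t.to(at::TensorOptions().dtype(at::kDouble));

  // A device without an index gets one filled in by ensure_has_index
  // (uncomment on a CUDA build):
  // at::Tensor c = t.to(at::TensorOptions().device(at::kCUDA));
}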

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 3 deletions
@@ -2211,15 +2211,18 @@
   CPU: dense_to_sparse
   CUDA: dense_to_sparse
 
-- func: to(Tensor self, Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) -> Tensor
+# to(Device) must not exist because all constructors of Device also work for
+# TensorOptions. Otherwise, an ambiguity error is thrown.
+# See NOTE [ TensorOptions Constructors ].
+- func: to(Tensor self, TensorOptions options, bool non_blocking=false, bool copy=false) -> Tensor
   variants: method
   device_guard: False
 
-- func: to(Tensor self, ScalarType dtype, bool non_blocking=false, bool copy=false) -> Tensor
+- func: to(Tensor self, Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) -> Tensor
   variants: method
   device_guard: False
 
-- func: to(Tensor self, Device device, bool non_blocking=false, bool copy=false) -> Tensor
+- func: to(Tensor self, ScalarType dtype, bool non_blocking=false, bool copy=false) -> Tensor
   variants: method
   device_guard: False
 

aten/src/ATen/templates/Tensor.h

Lines changed: 12 additions & 0 deletions
@@ -261,6 +261,18 @@ class CAFFE2_API Tensor {
   //Tensor * add(Tensor & b);
   ${tensor_method_declarations}
 
+  // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want
+  // at::kDouble and its friends to be TypeMetas, but that hasn't happened yet.
+  // Until it does, these methods maintain backward compatibility for C++
+  // usage like `x.to(y.dtype())`.
+  // TODO: remove the following two once at::kDouble and its friends are TypeMetas.
+  inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+  inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+
   template <typename F, typename... Args>
   auto m(F func, Args&&... params) const -> decltype(func(*this, std::forward<Args>(params)...)) {
     return func(*this, std::forward<Args>(params)...);
