Commit 56bdd87

ezyang authored and facebook-github-bot committed
Get rid of some uses of type() (pytorch#11215)
Summary:
Pull Request resolved: pytorch#11215

I found these by deleting the implicit conversion of Type to TensorOptions and then fixing the call sites that broke. This isn't a complete refactor; I ran out of steam after fixing this many and decided to keep the implicit conversion. Still, why waste a perfectly good refactor?

Reviewed By: gchanan, cpuhrsch

Differential Revision: D9634750

fbshipit-source-id: 4d8fb778e13e6e24b888b1314a02709b2cb00b62
1 parent 9ca63c5 | commit 56bdd87

10 files changed (+59 -63 lines)
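
For context on the pattern this commit migrates: at::Type bundles a backend (CPU/CUDA) and a scalar type, while TensorOptions is the builder-style bundle of dtype, device, and layout that factory functions such as at::empty and at::zeros accept. A minimal standalone sketch of the before/after styles, assuming the ATen API of this commit's era (not part of the diff; names are illustrative):

    #include <ATen/ATen.h>

    int main() {
      // An existing tensor supplies the dtype/device to match.
      at::Tensor self = at::ones({4}, at::kFloat);

      // Old style: pass the Type; this leans on the implicit
      // Type -> TensorOptions conversion the summary mentions.
      at::Tensor a = at::empty({4}, self.type());

      // New style: pass TensorOptions directly; variations
      // chain off the options object.
      at::Tensor b = at::empty({4}, self.options());
      at::Tensor c = at::empty({4}, self.options().dtype(at::kLong));
      return 0;
    }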

aten/src/ATen/cudnn/Descriptors.h (+1 -1)

@@ -257,7 +257,7 @@ struct AT_CUDA_API DropoutDescriptor
     AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
     AT_ASSERT(type.is_cuda());
     AT_ASSERT(type.scalarType() == kByte);
-    state = at::empty({static_cast<int64_t>(state_size)}, type);
+    state = at::empty({static_cast<int64_t>(state_size)}, type.options());
     AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
   }
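
Here type is the at::Type& passed into the descriptor; Type::options() packages its backend and scalar type as a TensorOptions, so the allocation no longer depends on the implicit Type conversion. A hedged sketch of the same call pattern outside cuDNN, again assuming era-appropriate ATen (variable names are illustrative):

    #include <ATen/ATen.h>

    int main() {
      at::Tensor probe = at::ones({2}, at::kByte);
      at::Type& type = probe.type();  // backend + scalar type bundle
      // Allocate a byte buffer matching that Type, via explicit options:
      at::Tensor state = at::empty({16}, type.options());
      return 0;
    }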

aten/src/ATen/native/Unique.cpp (+4 -4)

@@ -21,7 +21,7 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
   const Tensor& input = self.contiguous();
   const scalar_t* input_data = input.data<scalar_t>();
   std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
-  Tensor output = at::empty({static_cast<int64_t>(set.size())}, input.type());
+  Tensor output = at::empty({static_cast<int64_t>(set.size())}, input.options());
   scalar_t* output_data = output.data<scalar_t>();

   if (sorted) {
@@ -32,7 +32,7 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
     std::copy(set.begin(), set.end(), output_data);
   }

-  Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+  Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
   if (return_inverse) {
     inverse_indices.resize_(input.sizes());
     int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
@@ -103,12 +103,12 @@ std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
     return false;
   });

-  Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.type());
+  Tensor input_sorted = at::empty(input_flat.sizes(), input_flat.options());
   for (int i = 0; i < indices.size(); ++i) {
     input_sorted[i] = input_flat[indices[i]];
   }

-  Tensor inverse_indices = at::empty(indices.size(), self.type().toScalarType(kLong));
+  Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
   std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
   auto last = _unique_dim_cpu_impl(
     input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
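
The recurring rewrite in this file: self.type().toScalarType(kLong) manufactured a whole new Type just to change the element type, whereas self.options().dtype(kLong) overrides only the dtype field while keeping self's device and layout. A small sketch under the same API assumptions:

    #include <ATen/ATen.h>

    int main() {
      at::Tensor self = at::rand({3, 3});
      // An empty long tensor on self's device, later resized in place,
      // mirroring the inverse_indices handling in the diff above.
      at::Tensor inverse_indices = at::empty({0}, self.options().dtype(at::kLong));
      inverse_indices.resize_({3, 3});
      return 0;
    }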

aten/src/ATen/native/cuda/EmbeddingBag.cu (+5 -5)

@@ -175,15 +175,15 @@ Tensor embedding_bag_backward_cuda_sum_avg(

   Tensor &bag_size = const_cast<Tensor &>(bag_size_);

-  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.type());
+  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   ptrdiff_t numel = indices.numel();
   int64_t stride = grad_weight.stride(0);

-  auto sorted_indices = indices.type().tensor(indices.sizes());
-  auto orig_indices = indices.type().tensor(indices.sizes());
+  auto sorted_indices = at::empty_like(indices);
+  auto orig_indices = at::empty_like(indices);
   using device_ptr = thrust::device_ptr<int64_t>;

   // Sort the inputs into sorted with the corresponding indices; we
@@ -208,7 +208,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(

   Tensor count;
   if (scale_grad_by_freq) {
-    count = indices.type().tensor(indices.sizes());
+    count = at::empty_like(indices);

     auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
     auto policy = thrust::cuda::par(allocator).on(stream);
@@ -278,7 +278,7 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
                                        const Tensor &max_indices,
                                        int64_t num_weights) {

-  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.type());
+  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());

   int64_t stride = grad_weight.stride(0);
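
Beyond the grad.options() change, this file drops the deprecated Type-as-factory spelling indices.type().tensor(indices.sizes()) in favor of at::empty_like(indices), which allocates an uninitialized tensor with the same sizes, dtype, and device in one call. A minimal CPU-side sketch (the diff's tensors actually live on CUDA):

    #include <ATen/ATen.h>

    int main() {
      at::Tensor indices = at::zeros({8}, at::kLong);
      // Same sizes and options as `indices`, contents uninitialized.
      auto sorted_indices = at::empty_like(indices);
      auto orig_indices = at::empty_like(indices);
      return 0;
    }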

aten/src/ATen/native/cuda/LossCTC.cu (+3 -3)

@@ -185,7 +185,7 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
   int64_t tg_target_stride;

   int64_t max_target_length;
-  auto tg_batch_offsets = at::empty({batch_size}, TensorOptions(at::CPU(kLong)));
+  auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
   auto tg_batch_offsets_data = tg_batch_offsets.data<int64_t>();
   if (targets.dim() == 1) { // concatenated targets
     int64_t pos = 0;
@@ -219,8 +219,8 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
            " (while checking arguments for ", c, ")");
   }

-  auto target_lengths_t = at::tensor(target_lengths, targets.options().device(at::Device(at::Device::Type::CPU)).dtype(kLong)).toType(targets.type().toScalarType(kLong));
-  auto input_lengths_t = at::tensor(input_lengths, targets.options().device(at::Device(at::Device::Type::CPU)).dtype(kLong)).toType(targets.type().toScalarType(kLong));
+  auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong));
+  auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
   tg_batch_offsets = tg_batch_offsets.toType(targets.type().toScalarType(kLong));

   Tensor log_alpha = at::empty({batch_size, log_probs.size(0), 2*max_target_length+1}, log_probs.options());
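
Two simplifications land here: TensorOptions(at::CPU(kLong)) becomes the builder chain at::device(at::kCPU).dtype(at::kLong), and the lengths tensors lose the build-on-CPU-then-toType round trip because at::tensor(values, options) can target the desired dtype and device directly. A sketch of the builder form with illustrative values, under the same era assumptions:

    #include <ATen/ATen.h>
    #include <vector>

    int main() {
      // CPU tensor with int64 elements, independent of any default device.
      auto tg_batch_offsets =
          at::empty({4}, at::device(at::kCPU).dtype(at::kLong));

      std::vector<int64_t> target_lengths = {3, 1, 4, 1};
      // One-step construction at the desired dtype.
      auto target_lengths_t = at::tensor(target_lengths, at::dtype(at::kLong));
      return 0;
    }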

aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu (+4 -4)

@@ -78,7 +78,7 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT
   LongTensor rowIndices = indices.select(0, 0);
   LongTensor colIndices = indices.select(0, 1);
   IntTensor csr = _to_csr_int(rowIndices, m, nnz);
-  IntTensor colIndicesInt = at::empty({colIndices.size(0)}, indices.type().toScalarType(kInt));
+  IntTensor colIndicesInt = at::empty({colIndices.size(0)}, indices.options().dtype(kInt));
   colIndicesInt.copy_(colIndices);

   // No half support, so we don't have to use CUDATypeConversion
@@ -153,7 +153,7 @@ Tensor s_addmm_sparse_dense_cuda(
     Scalar beta,
     Scalar alpha
 ) {
-  Tensor r = t.type().tensor();
+  Tensor r = at::empty({0}, t.options());
   s_addmm_out_sparse_dense_cuda(r, t, sparse, dense, beta, alpha);
   return r;
 }
@@ -208,7 +208,7 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse

   LongTensor indices = at::empty({1, nnz}, CUDA(kLong));
   // create values in column-major format to avoid copying in spaddmm
-  Tensor values = at::empty({n, nnz}, dense.type());
+  Tensor values = at::empty({n, nnz}, dense.options());
   values.transpose_(0, 1);

   // why does sparse need to be cloned? If this is really necessary maybe we
@@ -434,7 +434,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
   Tensor t_values_ = t._values();
   LongTensor s_indices_ = src._indices();
   Tensor s_values_ = src._values();
-  LongTensor r_indices_ = t_indices_.type().tensor({sparseDims, max_nnz});
+  LongTensor r_indices_ = at::empty({sparseDims, max_nnz}, t_indices_.options());
   Tensor r_values_ = _new_values_with_size_of(t_values_, max_nnz).zero_();
   r_.resize_as_(src);
   _get_sparse_impl(r_)->set_indices_and_values_unsafe(r_indices_, r_values_);
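
The s_addmm_sparse_dense_cuda change shows the replacement for the zero-argument factory: t.type().tensor() produced an empty (zero-element) tensor of t's type, and at::empty({0}, t.options()) is the explicit equivalent, a placeholder that the out-variant then resizes and fills. Sketched standalone, under the same API assumptions:

    #include <ATen/ATen.h>

    int main() {
      at::Tensor t = at::rand({2, 2});
      // Zero-element placeholder with t's dtype/device; a subsequent
      // out-variant call is expected to resize and populate it.
      at::Tensor r = at::empty({0}, t.options());
      r.resize_({2, 2});
      return 0;
    }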

tools/autograd/derivatives.yaml (+6 -6)

@@ -283,7 +283,7 @@
   self: grad

 - name: gather(Tensor self, int64_t dim, Tensor index)
-  self: at::zeros(self.sizes(), grad.type()).scatter_add_(dim, index, grad)
+  self: at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)

 - name: ge_(Tensor self, Scalar other)
   self: zeros_like(self)
@@ -346,7 +346,7 @@
   value: grad.index_select(dim, index).sum()

 - name: index_select(Tensor self, int64_t dim, Tensor index)
-  self: at::zeros(self.sizes(), grad.type()).index_add_(dim, index, grad)
+  self: at::zeros(self.sizes(), grad.options()).index_add_(dim, index, grad)

 - name: inverse(Tensor self)
   self: -at::mm(result.t(), at::mm(grad, result.t()))
@@ -511,14 +511,14 @@
   self: zeros_like(grad)

 - name: normal(Tensor mean, double std, Generator generator)
-  mean: at::zeros(mean.sizes(), grad.type())
+  mean: at::zeros(mean.sizes(), grad.options())

 - name: normal(double mean, Tensor std, Generator generator)
-  std: at::zeros(std.sizes(), grad.type())
+  std: at::zeros(std.sizes(), grad.options())

 - name: normal(Tensor mean, Tensor std, Generator generator)
-  mean: at::zeros(mean.sizes(), grad.type())
-  std: at::zeros(std.sizes(), grad.type())
+  mean: at::zeros(mean.sizes(), grad.options())
+  std: at::zeros(std.sizes(), grad.options())

 - name: orgqr(Tensor self, Tensor input2)
   self: not_implemented("orgqr")
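
These YAML entries are C++ expressions spliced into the generated autograd functions, so the same idiom applies: a gradient buffer should inherit grad's dtype and device from its options rather than go through its Type. The gather backward from the diff, extracted as a standalone sketch with illustrative shapes:

    #include <ATen/ATen.h>

    int main() {
      at::Tensor self = at::rand({4, 3});
      at::Tensor index = at::zeros({2, 3}, at::kLong);
      at::Tensor grad = at::ones({2, 3});  // gradient of gather's output
      int64_t dim = 0;
      // Scatter the incoming gradient back into self's shape.
      at::Tensor grad_self =
          at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad);
      return 0;
    }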
