
Commit 3888555

Skylion007 authored and pytorchmergebot committed
Apply some more missing moves in aten native (pytorch#92983)
Add some additional missing moves to further improve vmap and related operators.

Pull Request resolved: pytorch#92983
Approved by: https://github.com/ezyang
1 parent 7e449e8 commit 3888555

11 files changed (+48, -32 lines)
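
The code hunks below all apply the same C++ pattern: a parameter taken by value (c10::SymInt, c10::optional<Generator>, Storage, or a local Tensor) is forwarded to a callee at its last use, so wrapping that final use in std::move turns a copy (typically an atomic refcount increment) into a cheap move. A minimal sketch of the idea, assuming a hypothetical Handle type as a stand-in for those refcounted value types (illustrative only, not PyTorch code):

#include <memory>
#include <utility>

// Hypothetical stand-in for a refcounted value type such as c10::SymInt:
// copying bumps a reference count, moving just steals the pointer.
struct Handle {
  std::shared_ptr<int> state;
};

// Takes its argument by value, like the *_symint sinks in the diffs below.
Handle consume(Handle h) { return h; }

Handle forward_copy(Handle h) {
  return consume(h);             // h is an lvalue here: this copies
}

Handle forward_move(Handle h) {
  return consume(std::move(h));  // last use of h: this moves, no copy
}

int main() {
  Handle a{std::make_shared<int>(42)};
  Handle b = forward_move(std::move(a));
  return (b.state && *b.state == 42) ? 0 : 1;
}

The same reasoning covers the std::move(result) edits inside the make_tuple calls: result is a local Tensor whose last use is constructing the returned tuple, so moving it avoids a refcount round-trip. Each touched file also gains #include <utility>, which declares std::move.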

aten/src/ATen/FunctionalInverses.cpp (+1, -1)

@@ -172,7 +172,7 @@ Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const
 
 Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::SymInt index) {
     // Pessimism: we can't reapply views for slice_scatter.
-    return base.select_scatter_symint(mutated_view, dim, index);
+    return base.select_scatter_symint(mutated_view, dim, std::move(index));
 }
 
 Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {

aten/src/ATen/functorch/BatchRulesBinaryOps.cpp (+3, -1)

@@ -9,6 +9,8 @@
 #include <ATen/Operators.h>
 #include <ATen/core/dispatch/Dispatcher.h>
 
+#include <utility>
+
 namespace at { namespace functorch {
 
 template <typename F, F Func, typename... ExtraArgs>
@@ -306,7 +308,7 @@ std::tuple<Tensor, optional<int64_t>> log_sigmoid_backward_batch_rule(
 }
 
 Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
-  return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous
+  return at::binomial(count, prob.contiguous(), std::move(gen)); // Bug in PyTorch, prob shouldn't need to be contiguous
 }
 
 TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {

aten/src/ATen/functorch/BatchRulesHelper.h (+4, -2)

@@ -19,6 +19,8 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/VmapGeneratedPlumbing.h>
 
+#include <utility>
+
 // This file contains helper functions for batching rules.
 
 namespace at { namespace functorch {
@@ -339,15 +341,15 @@ inline void boxed_all_tensors_have_optional_bdim(
       if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
-      (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
+      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
       continue;
     }
     TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
     value_ = reshape_dim_into(*bdim, 0, value_);
     if (tensor_idx == contig_tensor_index) {
       value_ = value_.contiguous();
     }
-    (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
+    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
   }
 
   op.callBoxed(stack);

aten/src/ATen/functorch/BatchRulesViews.cpp (+15, -14)

@@ -6,6 +6,7 @@
 
 #include <ATen/functorch/BatchRulesHelper.h>
 #include <iostream>
+#include <utility>
 
 #include <ATen/Operators.h>
 #include <ATen/functorch/PlumbingHelper.h>
@@ -236,7 +237,7 @@ std::tuple<Tensor, optional<int64_t>> squeeze_batch_rule(const Tensor& self, opt
   }
 
   auto result = self.view(squeezed_sizes);
-  return std::make_tuple(result, c10::optional<int64_t>(new_batch_idx));
+  return std::make_tuple(std::move(result), c10::optional<int64_t>(new_batch_idx));
 }
 
 std::tuple<Tensor, optional<int64_t>> squeeze_dims_batch_rule(
@@ -284,13 +285,13 @@ std::tuple<std::vector<Tensor>, optional<int64_t>> chunk_batching_rule(const Ten
 
 std::tuple<Tensor, optional<int64_t>> select_batching_rule(const Tensor& self, optional<int64_t> bdim, int64_t dim, c10::SymInt index) {
   if (!bdim) {
-    return std::make_tuple(self.select_symint(dim, index), nullopt);
+    return std::make_tuple(self.select_symint(dim, std::move(index)), nullopt);
   }
 
   auto _self = moveBatchDimToFront(self, bdim);
   auto dim_physical = getPhysicalDim(_self, true, dim);
-  auto result = _self.select_symint(dim_physical, index);
-  return std::make_tuple(result, 0);
+  auto result = _self.select_symint(dim_physical, std::move(index));
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> _reshape_alias_batch_rule(const Tensor& self, optional<int64_t> bdim, const c10::SymIntArrayRef shape, const c10::SymIntArrayRef strides) {
@@ -359,8 +360,8 @@ std::tuple<Tensor,optional<int64_t>> slice_batch_rule(
   auto self_ = moveBatchDimToFront(self, self_bdim);
   dim = getPhysicalDim(self, self_bdim.has_value(), dim);
 
-  auto result = self_.slice_symint(dim, start, end, step);
-  return std::make_tuple(result, 0);
+  auto result = self_.slice_symint(dim, std::move(start), std::move(end), std::move(step));
+  return std::make_tuple(std::move(result), 0);
 }
 
 static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
@@ -386,7 +387,7 @@ transpose_int_batch_rule(
   dim0 = getPhysicalDim(self, self_bdim.has_value(), dim0);
   dim1 = getPhysicalDim(self, self_bdim.has_value(), dim1);
   auto result = self_.transpose(dim0, dim1);
-  return std::make_tuple(result, 0);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> permute_batching_rule(
@@ -416,7 +417,7 @@ std::tuple<Tensor,optional<int64_t>> select_backward_batch_rule(
   c10::SymDimVector input_sizes_(input_sizes.size() + 1);
   input_sizes_[0] = grad_input_.sym_size(0);
   std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
-  auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, index);
+  auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, std::move(index));
   return std::make_tuple(std::move(result), 0);
 }
 
@@ -429,7 +430,7 @@ std::tuple<Tensor,optional<int64_t>> slice_backward_batch_rule(
   c10::SymDimVector input_sizes_(input_sizes.size() + 1);
   input_sizes_[0] = grad_input_.size(0);
   std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
-  auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, start, end, step);
+  auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, std::move(start), std::move(end), std::move(step));
   return std::make_tuple(std::move(result), 0);
 }
 
@@ -507,7 +508,7 @@ std::tuple<Tensor, optional<int64_t>> unfold_batch_rule(
   if (logical_rank==0) {
     result = result.squeeze(-1);
   }
-  return std::make_tuple(result, 0);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> narrow_copy_batch_rule(
@@ -517,9 +518,9 @@ std::tuple<Tensor, optional<int64_t>> narrow_copy_batch_rule(
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto logical_rank = rankWithoutBatchDim(self, self_bdim);
   dim = maybe_wrap_dim(dim, logical_rank) + 1;
-  auto result = self_.narrow_copy_symint(dim, start, length);
+  auto result = self_.narrow_copy_symint(dim, std::move(start), std::move(length));
 
-  return std::make_tuple(result, 0);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<std::vector<Tensor>, optional<int64_t>> unsafe_split_batch_rule(
@@ -531,8 +532,8 @@ std::tuple<std::vector<Tensor>, optional<int64_t>> unsafe_split_batch_rule(
   auto self_ = moveBatchDimToFront(self, self_bdim);
   auto logical_rank = rankWithoutBatchDim(self, self_bdim);
   dim = maybe_wrap_dim(dim, logical_rank) + 1;
-  auto result = self_.unsafe_split_symint(split_size, dim);
-  return std::make_tuple(result, 0);
+  auto result = self_.unsafe_split_symint(std::move(split_size), dim);
+  return std::make_tuple(std::move(result), 0);
 }
 
 std::tuple<Tensor, optional<int64_t>> movedim_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef source, IntArrayRef destination) {

aten/src/ATen/functorch/Interpreter.cpp (+3, -1)

@@ -6,6 +6,8 @@
 #include <ATen/functorch/ADInterpreters.h>
 #include <ATen/functorch/DynamicLayer.h>
 
+#include <utility>
+
 namespace at { namespace functorch {
 
 static DispatchKeySet get_all_dynlayer_keyset() {
@@ -92,7 +94,7 @@ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
       auto result = unwrapIfDead(tensor);
       auto* wrapper = maybeGetTensorWrapper(result);
       TORCH_INTERNAL_ASSERT(wrapper == nullptr);
-      auto* batched = maybeGetBatchedImpl(result);
+      auto* batched = maybeGetBatchedImpl(std::move(result));
       TORCH_INTERNAL_ASSERT(batched == nullptr);
       return tensor;
     });

aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp (+4, -2)

@@ -16,6 +16,8 @@
 #include <ATen/functorch/BatchedFallback.h>
 #include <ATen/functorch/BatchRulesHelper.h>
 
+#include <utility>
+
 namespace at {
 namespace functorch {
 
@@ -476,7 +478,7 @@ Tensor as_strided_batching_rule(
     optional<c10::SymInt> storage_offset) {
   if (!participatesInCurrentLevel(tensor)) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    return at::as_strided_symint(tensor, sizes, strides, storage_offset);
+    return at::as_strided_symint(tensor, sizes, strides, std::move(storage_offset));
   }
   auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
   auto num_batch_dims = physical_view.numBatchDims();
@@ -511,7 +513,7 @@ Tensor as_strided_batching_rule(
   // and creates a tensor y such that each y[i] references the same memory
   // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
   auto result = physical_view.tensor().as_strided_symint(
-      physical_sizes, physical_strides, storage_offset);
+      physical_sizes, physical_strides, std::move(storage_offset));
   return physical_view.getPhysicalToLogicalMap().apply(result);
 }
 

aten/src/ATen/native/ComplexHelper.h (+3, -1)

@@ -8,6 +8,8 @@
 #else
 #include <ATen/ops/view_as_real_native.h>
 #include <ATen/ops/view_as_complex_native.h>
+
+#include <utility>
 #endif
 
 // WARNING: this header contains non-inline functions and should be only
@@ -47,7 +49,7 @@ Tensor _view_as_real_physical(const Tensor& self) {
   auto new_strides = computeStrideForViewAsReal(self.sym_strides());
   auto new_storage_offset = self.sym_storage_offset() * 2;
   const auto float_type = c10::toRealValueType(self.scalar_type());
-  auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides);
+  auto real_tensor = view_tensor(self, float_type, std::move(new_storage_offset), new_sizes, new_strides);
   return real_tensor;
 }
 

aten/src/ATen/native/Normalization.cpp (+5, -4)

@@ -48,8 +48,9 @@
 #include <ATen/ops/sqrt.h>
 #endif
 
-#include <vector>
 #include <c10/core/SymIntArrayRef.h>
+#include <utility>
+#include <vector>
 
 static const int MIOPEN_DIM_MAX = 5;
 
@@ -490,7 +491,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   auto options = input.options().dtype(
       at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
   auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
-  auto save_invstd = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
+  auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options);
 
   // don't return view of input, don't return empty tensor because it will break gradient chain
   auto out = input.clone();
@@ -514,7 +515,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     check_dims_match_num_input_features("weight", num_features, weight.sym_numel());
   }
   if (bias.defined()) {
-    check_dims_match_num_input_features("bias", num_features, bias.sym_numel());
+    check_dims_match_num_input_features("bias", std::move(num_features), bias.sym_numel());
   }
 
   const bool use_cudnn = (
@@ -672,7 +673,7 @@ Tensor instance_norm(
     at::alias(running_mean).copy_(running_mean_.view_symint({ b, c }).mean(0, false));
   }
   if (running_var.defined()) {
-    at::alias(running_var).copy_(running_var_.view_symint({ b, c }).mean(0, false));
+    at::alias(running_var).copy_(running_var_.view_symint({ std::move(b), std::move(c) }).mean(0, false));
   }
 
   return out.view_symint(input.sym_sizes());

aten/src/ATen/native/Pool.h (+3, -1)

@@ -4,6 +4,8 @@
 #include <ATen/native/DispatchStub.h>
 #include <c10/util/irange.h>
 
+#include <utility>
+
 #pragma once
 
 namespace at {
@@ -93,7 +95,7 @@ inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
 
 inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
     c10::SymInt inputSize, c10::SymInt kernelSize, int64_t stride, int64_t dilation) {
-  return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
+  return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), stride, dilation);
 }
 
 // AveragePool2d/DilatedMaxPool2d (forward)

aten/src/ATen/native/Resize.h (+3, -1)

@@ -7,6 +7,8 @@
 
 #include <c10/core/CPUAllocator.h>
 
+#include <utility>
+
 
 namespace at { namespace native {
 
@@ -130,7 +132,7 @@ static inline void checkSetStorage(Tensor& result, Storage storage, T storage_of
                 "Attempted to set the storage of a tensor on device \"", result.storage().device(),
                 "\" to a storage on different device \"", storage.device(),
                 "\". This is no longer allowed; the devices must match.");
-    result.unsafeGetTensorImpl()->set_storage_keep_dtype(storage);
+    result.unsafeGetTensorImpl()->set_storage_keep_dtype(std::move(storage));
   }
 
   // storageOffset

aten/src/ATen/native/TensorShape.cpp (+4, -4)

@@ -1804,7 +1804,7 @@ Tensor select_symint(const Tensor& self, int64_t dim, c10::SymInt index) {
 
 Tensor select_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
   auto grad_input = at::zeros_symint(input_sizes, grad.options());
-  grad_input.select_symint(dim, index).copy_(grad);
+  grad_input.select_symint(dim, std::move(index)).copy_(grad);
   return grad_input;
 }
 
@@ -3879,7 +3879,7 @@ at::Tensor clone_preserve_strides(const at::Tensor& self) {
   auto nbytes = self.storage().sym_nbytes();
   TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
   auto numel = nbytes / dtype_size;
-  auto self_full_size = self.as_strided_symint({numel}, {1}, 0);
+  auto self_full_size = self.as_strided_symint({std::move(numel)}, {1}, 0);
   auto clone = self_full_size.clone();
   auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
   return out;
@@ -3896,7 +3896,7 @@ at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t
 }
 at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
   auto output = clone_preserve_strides(self);
-  auto slice = output.select_symint(dim, index);
+  auto slice = output.select_symint(dim, std::move(index));
   TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
   slice.copy_(src);
   return output;
@@ -4039,7 +4039,7 @@ at::Tensor& _reshape_alias_copy_out(const at::Tensor & self, at::IntArrayRef siz
 
 
 at::Tensor& select_copy_symint_out(const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out) {
-  auto tmp = self.select_symint(dim, index);
+  auto tmp = self.select_symint(dim, std::move(index));
   out.copy_(tmp);
   return out;
 }
