
Commit e26cb06

peterbell10 authored and pytorchmergebot committed
squeeze: allow squeezing multiple dimensions at once (pytorch#89017)
Ref pytorch#70924

This addresses part 1 of the issue, allowing `torch.squeeze` to be passed a tuple of dimensions, e.g.

```python
x.squeeze(0).squeeze(0)
```

can now be written

```python
x.squeeze((0, 1))
```

(assuming `x` has at least 2 dimensions)

Pull Request resolved: pytorch#89017
Approved by: https://github.com/albanD
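For illustration, a minimal usage sketch of the new overload (assuming a PyTorch build that includes this change; the tensor shapes below are made-up examples). Like the existing single-dimension overload, listed dimensions whose size is not 1 should simply be left unchanged:

```python
import torch

x = torch.zeros(1, 1, 3, 1)

# Old style: chain single-dimension squeezes.
y_old = x.squeeze(0).squeeze(0)   # shape (3, 1)

# New style: pass a tuple of dimensions in one call.
y_new = x.squeeze((0, 1))         # shape (3, 1)

assert y_old.shape == y_new.shape == (3, 1)

# Dimensions that do not have size 1 are left alone:
# dim 2 has size 3, so only dim 0 is removed here.
z = x.squeeze((0, 2))
assert z.shape == (1, 3, 1)
```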
1 parent 3120054 · commit e26cb06

20 files changed: +347 −132 lines changed

aten/src/ATen/FunctionalInverses.cpp

+21 −8
@@ -3,6 +3,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/ExpandUtils.h>
+#include <ATen/WrapDimUtilsMulti.h>
 
 #include <utility>
 namespace at {
@@ -42,18 +43,26 @@ Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, bool re
   return result;
 }
 
-Tensor unsqueeze_copy_to(const Tensor & self, int64_t dim, c10::SymIntArrayRef sizes, bool reapply_views) {
-  dim = at::maybe_wrap_dim(dim, sizes.size());
+Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, bool reapply_views) {
+  const auto ndim = sizes.size();
+  const auto mask = at::dim_list_to_bitset(dim, ndim);
   // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
   // unsqueezing in the backward.
-  if (sizes.size() > 0 && sizes[dim] == 1) {
-    if (reapply_views) {
-      return at::unsqueeze(self, dim);
-    } else {
-      return at::unsqueeze_copy(self, dim);
+  if (ndim == 0) {
+    return self;
+  }
+
+  Tensor result = self;
+  for (const auto d : c10::irange(ndim)) {
+    if (mask.test(d) && sizes[d] == 1) {
+      if (reapply_views) {
+        result = at::unsqueeze(result, d);
+      } else {
+        result = at::unsqueeze_copy(result, d);
+      }
     }
   }
-  return self;
+  return result;
 }
 
 // Note [Functionalization Pass: View Inverses].
@@ -215,6 +224,10 @@ Tensor FunctionalInverses::squeeze_copy_dim_inverse(const Tensor& base, const Te
   return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
 }
 
+Tensor FunctionalInverses::squeeze_copy_dims_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, IntArrayRef dim) {
+  return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views);
+}
+
 Tensor FunctionalInverses::t_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
   if (reapply_views) {
     return at::t(mutated_view);

aten/src/ATen/LegacyBatchingRegistrations.cpp

+8
@@ -296,6 +296,13 @@ Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
   return self_physical.getPhysicalToLogicalMap().apply(result);
 }
 
+Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) {
+  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto dims_physical = self_physical.getPhysicalDims(dims);
+  auto result = self_physical.tensor().squeeze(dims_physical);
+  return self_physical.getPhysicalToLogicalMap().apply(result);
+}
+
 Tensor trace_batching_rule(const Tensor& self) {
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   // Batched Diagonal View
@@ -1116,6 +1123,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("split_with_sizes", split_with_sizes_batching_rule);
   m.impl("squeeze", squeeze_batching_rule);
   m.impl("squeeze.dim", squeeze_dim_batching_rule);
+  m.impl("squeeze.dims", squeeze_dims_batching_rule);
   m.impl("t", native::t); // composite wrt autograd
   m.impl("trace", trace_batching_rule);
   m.impl("transpose.int", transpose_int_batching_rule);

aten/src/ATen/NamedTensorUtils.cpp

+14
@@ -241,6 +241,20 @@ std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor) {
   return outnames;
 }
 
+std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor, std::bitset<dim_bitset_size> dims) {
+  if (!tensor.has_names()) {
+    return {};
+  }
+  std::vector<Dimname> outnames;
+  auto tensor_names = tensor.names();
+  for (const auto d : c10::irange(tensor.dim())) {
+    if (!dims.test(d) || tensor.sym_sizes()[d] != 1) {
+      outnames.push_back(tensor_names[d]);
+    }
+  }
+  return outnames;
+}
+
 std::vector<Dimname> compute_diagonal_outnames(
     const Tensor& tensor,
     int64_t dim1,

aten/src/ATen/NamedTensorUtils.h

+4
@@ -1,6 +1,7 @@
 #pragma once
 #include <ATen/NamedTensor.h>
 #include <ATen/TensorNames.h>
+#include <ATen/WrapDimUtilsMulti.h>
 
 #include <ATen/core/DimVector.h>
 #include <ATen/core/Tensor.h>
@@ -144,6 +145,9 @@ TORCH_API std::vector<Dimname> compute_bmm_outnames(
     const Tensor& other);
 
 TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
+TORCH_API std::vector<Dimname> compute_squeeze_outnames(
+    const Tensor& tensor,
+    std::bitset<dim_bitset_size> dims);
 
 std::vector<Dimname> compute_diagonal_outnames(
     const Tensor& tensor,

aten/src/ATen/functorch/BatchRulesViews.cpp

+29 −21
@@ -239,34 +239,41 @@ std::tuple<Tensor, optional<int64_t>> squeeze_batch_rule(const Tensor& self, opt
   return std::make_tuple(result, c10::optional<int64_t>(new_batch_idx));
 }
 
-std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+std::tuple<Tensor, optional<int64_t>> squeeze_dims_batch_rule(
+    const Tensor& self, optional<int64_t> bdim, IntArrayRef dims) {
   TORCH_INTERNAL_ASSERT(bdim.has_value());
   // Special case for scalar arrays to replicate PyTorch behavior.
-  if (self.dim() == 1) {
-    TORCH_CHECK(dim == 0, "Dimension is out of range (expected to be in range of [-1, 0], but got ", dim);
+  auto ndim = self.dim();
+  if (ndim == 1) {
+    TORCH_CHECK(
+        dims.size() == 0 || (dims.size() == 1 && dims[0] == 0),
+        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
     return std::make_tuple(self.alias(), bdim);
   }
 
-  // Calculate the proper offset if dim is negative.
-  auto actual_dim = dim;
-  if (dim < 0) {
-    actual_dim = self.dim() + dim - 1;
-  }
-  if (actual_dim < bdim) {
-    // Since dimension to be squeezed is before the batch dimension pass as-is.
-    auto original_size = self.dim();
-    auto result = self.squeeze(actual_dim);
-    auto updated_batch_idx = *bdim;
-    if (result.dim() != original_size) {
-      // A column before batch dimension has been dropped so adjust accordingly.
-      --updated_batch_idx;
+  // Adjust any dimensions higher than the batch dimension
+  DimVector adjusted_dims(dims.begin(), dims.end());
+  int64_t updated_batch_idx = *bdim;
+  for (auto &d : adjusted_dims) {
+    auto actual_dim = c10::maybe_wrap_dim(d, ndim - 1);
+    if (actual_dim < *bdim) {
+      d = actual_dim;
+      if (self.sym_size(actual_dim) == 1) {
+        // A column before batch dimension will be dropped so adjust accordingly.
+        --updated_batch_idx;
+      }
+    } else {
+      // Since dimension to be squeezed is after the batch dimension adjust by one to account
+      // for the original batch dimension. In this case batch dimension won't move.
+      d = actual_dim + 1;
     }
-    return std::make_tuple(result, optional<int64_t>(updated_batch_idx));
-  } else {
-    // Since dimension to be squeezed is after the batch dimension adjust by one to account
-    // for the original batch dimension. In this case batch dimension won't move.
-    return std::make_tuple(self.squeeze(actual_dim + 1), bdim);
   }
+  return std::make_tuple(self.squeeze(adjusted_dims), optional<int64_t>(updated_batch_idx));
+}
+
+std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(
+    const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+  return squeeze_dims_batch_rule(self, bdim, {dim});
 }
 
 std::tuple<std::vector<Tensor>, optional<int64_t>> chunk_batching_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t chunks, int64_t dim) {
@@ -547,6 +554,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT2(select, int, select_batching_rule);
   VMAP_SUPPORT(squeeze, squeeze_batch_rule);
   VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule);
+  VMAP_SUPPORT2(squeeze, dims, squeeze_dims_batch_rule);
   VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule);
   VMAP_SUPPORT(roll, roll_batch_rule);
   VMAP_SUPPORT(permute, permute_batching_rule);

aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp

+29 −16
@@ -144,40 +144,52 @@ std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntAr
   return result;
 }
 
-Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
+Tensor& squeeze_dims__batching_rule(Tensor& self, IntArrayRef dims) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    return self.squeeze_(dim);
+    return self.squeeze_(dims);
   }
   auto* batched = maybeGetBatchedImpl(self);
   const auto bdim = batched->bdim();
   auto logical_dim = self.dim();
 
-  // If logically a scalar tensor, then Tensor.squeeze_(dim) is a no-op
   if (logical_dim == 0) {
+    TORCH_CHECK(
+        dims.size() == 0 || (dims.size() == 1 && dims[0] == 0),
+        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
     return self;
   }
 
-  dim = maybe_wrap_dim(dim, logical_dim);
-  if (dim >= bdim) {
-    dim = dim + 1;
-    batched->value().squeeze_(dim);
-    batched->refreshTensorMetadata();
-    return self;
+  // Adjust any dimensions higher than the batch dimension
+  DimVector adjusted_dims(dims.begin(), dims.end());
+  int64_t updated_batch_idx = bdim;
+  for (auto &d : adjusted_dims) {
+    auto actual_dim = c10::maybe_wrap_dim(d, logical_dim);
+    if (actual_dim < bdim) {
+      d = actual_dim;
+      if (batched->value().sym_size(actual_dim) == 1) {
+        // A column before batch dimension will be dropped so adjust accordingly.
+        --updated_batch_idx;
+      }
+    } else {
+      // Since dimension to be squeezed is after the batch dimension adjust by one to account
+      // for the original batch dimension. In this case batch dimension won't move.
+      d = actual_dim + 1;
+    }
   }
 
-  // Tensor.squeeze_(0) is a no-op if dim 0 has a size other than 1
-  if (batched->value().size(dim) != 1) {
-    return self;
+  batched->value().squeeze_(adjusted_dims);
+  if (updated_batch_idx != bdim) {
+    batched->unsafe_set_bdim(updated_batch_idx);
   }
-
-  // dim < bdim, so we need to adjust bdim
-  batched->value().squeeze_(dim);
-  batched->unsafe_set_bdim(bdim - 1);
   batched->refreshTensorMetadata();
   return self;
 }
 
+Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
+  return squeeze_dims__batching_rule(self, {dim});
+}
+
 Tensor& squeeze__batching_rule(Tensor& self) {
   if (!participatesInCurrentLevel(self)) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
@@ -816,6 +828,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   // still legacy b/c needs special inplace rules
   m.impl("squeeze_", squeeze__batching_rule);
   m.impl("squeeze_.dim", squeeze_dim__batching_rule);
+  m.impl("squeeze_.dims", squeeze_dims__batching_rule);
   m.impl("unsqueeze_", unsqueeze__batching_rule);
   m.impl("transpose_", transpose__batching_rule);
 

aten/src/ATen/native/ReduceOps.cpp

+2 −13
@@ -90,6 +90,7 @@
 #include <ATen/ops/slice.h>
 #include <ATen/ops/special_logsumexp_native.h>
 #include <ATen/ops/sqrt.h>
+#include <ATen/ops/squeeze.h>
 #include <ATen/ops/stack.h>
 #include <ATen/ops/std.h>
 #include <ATen/ops/std_mean.h>
@@ -1381,23 +1382,11 @@ Tensor nanmean(
   return at::nansum(self, dim, keepdim, opt_dtype).div(factor);
 }
 
-static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
-  int ndims = self.sizes().size();
-  auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
-  Tensor result = self;
-  for (int i = ndims - 1; i >= 0; --i) {
-    if (dims_to_squeeze[i]) {
-      result = result.squeeze(i);
-    }
-  }
-  return result;
-}
-
 static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
   // can't take max of empty tensor
   if (self.numel() != 0) {
     auto maxes = at::amax(self, dims, true);
-    auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
+    auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims));
     maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
     at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
     result.log_().add_(maxes_squeezed);
