@@ -239,34 +239,41 @@ std::tuple<Tensor, optional<int64_t>> squeeze_batch_rule(const Tensor& self, opt
   return std::make_tuple(result, c10::optional<int64_t>(new_batch_idx));
 }

-std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+std::tuple<Tensor, optional<int64_t>> squeeze_dims_batch_rule(
+    const Tensor& self, optional<int64_t> bdim, IntArrayRef dims) {
   TORCH_INTERNAL_ASSERT(bdim.has_value());
   // Special case for scalar arrays to replicate PyTorch behavior.
-  if (self.dim() == 1) {
-    TORCH_CHECK(dim == 0, "Dimension is out of range (expected to be in range of [-1, 0], but got ", dim);
+  auto ndim = self.dim();
+  if (ndim == 1) {
+    TORCH_CHECK(
+        dims.size() == 0 || (dims.size() == 1 && dims[0] == 0),
+        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
     return std::make_tuple(self.alias(), bdim);
   }

-  // Calculate the proper offset if dim is negative.
-  auto actual_dim = dim;
-  if (dim < 0) {
-    actual_dim = self.dim() + dim - 1;
-  }
-  if (actual_dim < bdim) {
-    // Since dimension to be squeezed is before the batch dimension pass as-is.
-    auto original_size = self.dim();
-    auto result = self.squeeze(actual_dim);
-    auto updated_batch_idx = *bdim;
-    if (result.dim() != original_size) {
-      // A column before batch dimension has been dropped so adjust accordingly.
-      --updated_batch_idx;
+  // Adjust any dimensions higher than the batch dimension
+  DimVector adjusted_dims(dims.begin(), dims.end());
+  int64_t updated_batch_idx = *bdim;
+  for (auto &d : adjusted_dims) {
+    auto actual_dim = c10::maybe_wrap_dim(d, ndim - 1);
+    if (actual_dim < *bdim) {
+      d = actual_dim;
+      if (self.sym_size(actual_dim) == 1) {
+        // A column before batch dimension will be dropped so adjust accordingly.
+        --updated_batch_idx;
+      }
+    } else {
+      // Since dimension to be squeezed is after the batch dimension adjust by one to account
+      // for the original batch dimension. In this case batch dimension won't move.
+      d = actual_dim + 1;
     }
-    return std::make_tuple(result, optional<int64_t>(updated_batch_idx));
-  } else {
-    // Since dimension to be squeezed is after the batch dimension adjust by one to account
-    // for the original batch dimension. In this case batch dimension won't move.
-    return std::make_tuple(self.squeeze(actual_dim + 1), bdim);
   }
+  return std::make_tuple(self.squeeze(adjusted_dims), optional<int64_t>(updated_batch_idx));
+}
+
+std::tuple<Tensor, optional<int64_t>> squeeze_dim_batch_rule(
+    const Tensor& self, optional<int64_t> bdim, int64_t dim) {
+  return squeeze_dims_batch_rule(self, bdim, {dim});
 }

 std::tuple<std::vector<Tensor>, optional<int64_t>> chunk_batching_rule(const Tensor& self, optional<int64_t> self_bdim, int64_t chunks, int64_t dim) {
@@ -547,6 +554,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT2(select, int, select_batching_rule);
   VMAP_SUPPORT(squeeze, squeeze_batch_rule);
   VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule);
+  VMAP_SUPPORT2(squeeze, dims, squeeze_dims_batch_rule);
   VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule);
   VMAP_SUPPORT(roll, roll_batch_rule);
   VMAP_SUPPORT(permute, permute_batching_rule);
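For illustration only, a minimal standalone sketch of the dimension-remapping arithmetic that squeeze_dims_batch_rule performs around the batch dimension; the shape, bdim, and dims values below are hypothetical example inputs, and only the standard library is used (no functorch API).

// Sketch (assumed example values): remap logical squeeze dims to physical dims.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical physical shape [2, 5, 1, 3] with the batch dimension at index 1,
  // so the logical (per-example) shape seen by the user is [2, 1, 3].
  std::vector<int64_t> sizes = {2, 5, 1, 3};
  int64_t bdim = 1;
  int64_t ndim = static_cast<int64_t>(sizes.size());  // physical rank, batch dim included
  std::vector<int64_t> dims = {-2};                    // logical dims requested for squeezing

  int64_t updated_batch_idx = bdim;
  for (auto& d : dims) {
    // Wrap negative dims against the logical rank (ndim - 1), like c10::maybe_wrap_dim.
    int64_t actual_dim = d < 0 ? d + (ndim - 1) : d;
    if (actual_dim < bdim) {
      d = actual_dim;                 // dims before the batch dim keep their physical index
      if (sizes[actual_dim] == 1) {
        --updated_batch_idx;          // a size-1 dim before the batch dim will be dropped
      }
    } else {
      d = actual_dim + 1;             // dims at or after the batch dim shift past it
    }
  }
  // Prints "squeeze physical dim 2, batch index stays at 1".
  std::cout << "squeeze physical dim " << dims[0]
            << ", batch index stays at " << updated_batch_idx << "\n";
}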