
Commit 7b8f73d

jbschlosser authored and facebook-github-bot committed
No-batch-dim support for ConvNd (pytorch#70506)

Summary: Pull Request resolved: pytorch#70506
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33355034
Pulled By: jbschlosser
fbshipit-source-id: 5a42645299b1d82cee7d461826acca1c5b35a71c
1 parent 6896b2d commit 7b8f73d
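In short, the change lets the convNd / conv_transposeNd functionals (and thus the corresponding modules) accept inputs without a leading batch dimension: the input is batched internally, convolved, and squeezed back. A minimal Python sketch of the new behavior (shapes are illustrative, not from the commit):

import torch
import torch.nn.functional as F

# Unbatched conv2d input: (channels, height, width), no leading batch dim.
x = torch.randn(3, 28, 28)
w = torch.randn(8, 3, 5, 5)  # (out_channels, in_channels, kH, kW)

out = F.conv2d(x, w)
print(out.shape)  # torch.Size([8, 24, 24]) -- output is unbatched too

# Batched inputs behave exactly as before.
out_b = F.conv2d(x.unsqueeze(0), w)
print(out_b.shape)  # torch.Size([1, 8, 24, 24])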

File tree: 7 files changed (+318 −99 lines)

aten/src/ATen/native/Convolution.cpp

Lines changed: 81 additions & 24 deletions
@@ -575,6 +575,30 @@ static void check_shape_backward(
   check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params);
 }
 
+// Given an input tensor and an expected number of spatial dimensions, checks that the
+// input is a valid shape and returns the batched form of the input.
+//
+// Args:
+//     input (Tensor): Input tensor
+//     num_spatial_dims (int): Number of spatial dimensions expected for the input
+//     func_name (string): Function name to produce a nice error message for invalid input
+//
+// Returns a std::tuple containing:
+//     batched_input (Tensor): Input with a batch dimension
+//     is_batched (bool): Indicates whether the original input was already batched
+static std::tuple<Tensor, bool> batchify(
+    const Tensor& input,
+    const int64_t num_spatial_dims,
+    const std::string& func_name) {
+  const auto dim_count_no_batch = num_spatial_dims + 1;
+  const auto dim_count_batch = dim_count_no_batch + 1;
+  const auto is_batched = (input.dim() == dim_count_batch);
+  TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
+      "Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
+      "D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
+  return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
+}
+
 static void check_input_same_type_as_parameters(
     const Tensor& input,
     const Tensor& weight,
@@ -618,36 +642,45 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
 
 
 at::Tensor conv1d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;
 
-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         false, {0}, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
+  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor conv2d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;
 
-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         false, {{0, 0}}, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
+  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor conv3d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;
 
-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         false, {{0, 0, 0}}, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
+  auto output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 
@@ -736,60 +769,84 @@ Tensor _convolution_mode(
 }
 
 at::Tensor conv1d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
     IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
     int64_t groups) {
-  return at::_convolution_mode(
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
+  auto output = at::_convolution_mode(
       input, weight, bias, stride, std::move(padding), dilation, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor conv2d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
    IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
     int64_t groups) {
-  return at::_convolution_mode(
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
+  auto output = at::_convolution_mode(
      input, weight, bias, stride, std::move(padding), dilation, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor conv3d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias,
     IntArrayRef stride, c10::string_view padding, IntArrayRef dilation,
     int64_t groups) {
-  return at::_convolution_mode(
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
+  auto output = at::_convolution_mode(
      input, weight, bias, stride, std::move(padding), dilation, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor conv_transpose1d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;
 
-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         true, output_padding, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
+  auto output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor conv_transpose2d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;
 
-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         true, output_padding, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
+  auto output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor conv_transpose3d(
-    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
+    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
     IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
   const Tensor& bias = *bias_maybe_owned;
 
-  return at::convolution(input, weight, bias, stride, padding, dilation,
-                         true, output_padding, groups);
+  Tensor input;
+  bool is_batched;
+  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
+  auto output = at::convolution(
+      input, weight, bias, stride, padding, dilation, true, output_padding, groups);
+  return is_batched ? output : output.squeeze(0);
 }
 
 at::Tensor convolution(
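Note that batchify gates every entry point above: inputs with any other dimensionality are rejected up front with the new error message. A hedged sketch of the resulting behavior, reusing the shapes from the updated test_nn.py test further below:

import torch
import torch.nn.functional as F

w = torch.randn(6, 1, 5, 5)
try:
    # A 5D input is neither 3D (unbatched) nor 4D (batched) for conv2d.
    F.conv2d(torch.randn(1, 10, 1, 28, 28), w)
except RuntimeError as e:
    # Expected 3D (unbatched) or 4D (batched) input to conv2d, but got
    # input of size: [1, 10, 1, 28, 28]
    print(e)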

test/test_cpp_extensions_aot.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def test_conv_backend_override(self):
         bias = torch.empty(6, device='ort')
 
         # Make sure forward is overriden
-        out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
+        out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
         self.assertEqual(ort_extension.get_test_int(), 2)
         self.assertEqual(out.shape[0], input.shape[0])
         self.assertEqual(out.shape[1], weight.shape[0])

test/test_jit.py

Lines changed: 1 addition & 1 deletion
@@ -13005,7 +13005,7 @@ def forward(self, x):
                 return self.conv(x)
         foo = Foo()
         # testing that the correct error message propagates
-        with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
+        with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"):
             foo(torch.ones([123]))  # wrong size
 
     def test_builtin_error_messsage(self):

test/test_modules.py

Lines changed: 8 additions & 4 deletions
@@ -12,7 +12,7 @@
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
-from unittest.mock import patch
+from unittest.mock import patch, call
 
 
 class TestModule(TestCase):
@@ -122,9 +122,9 @@ def test_factory_kwargs(self, device, dtype, module_info):
             with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                 m = module_cls(*args, **kwargs)
             uninit_param_new.mock.assert_has_calls(
-                [mock.call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
+                [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
             uninit_buffer_new.mock.assert_has_calls(
-                [mock.call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
+                [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
         else:
             # Check device placement and dtype for created parameters and buffers.
             # Only verify floating point dtypes since that's what the kwarg applies to.
@@ -421,9 +421,13 @@ def _test_gradients_helper(self, device, dtype, module_info, check):
 
         params = tuple(m.parameters())
 
-        # === Perform gradient check on the input_args ===
+        # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
         input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+        if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
+            with torch.no_grad():
+                m(*input_args, **input_kwargs)
 
+        # === Perform gradient check on the input_args ===
         other_kwargs = {}
         kwarg_tensors = []
         for name, obj in input_kwargs.items():
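The gradcheck change above accounts for lazy modules, whose parameters stay uninitialized until a first forward pass; gradcheck needs materialized parameters. A small sketch of the same pattern outside the test harness (module choice is illustrative):

import torch

m = torch.nn.LazyConv2d(out_channels=4, kernel_size=3)  # in_channels inferred on first forward
assert isinstance(m, torch.nn.modules.lazy.LazyModuleMixin)

x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    m(x)  # dry run materializes weight/bias; in_channels becomes 3

params = tuple(m.parameters())  # now concrete tensors, safe to hand to gradcheck
print(m.weight.shape)  # torch.Size([4, 3, 3, 3])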

test/test_nn.py

Lines changed: 6 additions & 6 deletions
@@ -895,8 +895,8 @@ def test_mismatch_shape_conv2d(self):
         w = torch.randn(6, 1, 5, 5)
 
         with self.assertRaisesRegex(RuntimeError,
-                                    r'Expected 4-dimensional input for 4-dimensional weight \[6, 1, 5, 5\],' +
-                                    r' but got 5-dimensional input of size \[1, 10, 1, 28, 28\] instead'):
+                                    r'Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got ' +
+                                    r'input of size: \[1, 10, 1, 28, 28\]'):
 
             F.conv2d(x, w)
 
@@ -6172,9 +6172,9 @@ def test_conv_modules_raise_error_on_incorrect_input_size(self):
                    nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype),
                    nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)]
 
-        invalid_input_dims = [(2, 4), (2, 4),
-                              (3, 5), (3, 5),
-                              (4, 6), (4, 6)]
+        invalid_input_dims = [(1, 4), (1, 4),
+                              (2, 5), (2, 5),
+                              (3, 6), (3, 6)]
 
         for invalid_dims, module in zip(invalid_input_dims, modules):
             for dims in invalid_dims:
@@ -13402,7 +13402,7 @@ def test_conv2d_same_padding_backward(self, device):
         gx_expect, gy_expect = x.grad, y.grad
         x.grad, y.grad = None, None
 
-        z = F.conv1d(x, y, padding='same')
+        z = F.conv2d(x, y, padding='same')
         z.sum().backward()
         self.assertEqual(gx_expect, x.grad)
         self.assertEqual(gy_expect, y.grad)
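The shift in invalid_input_dims follows from the new contract: each ConvNd/ConvTransposeNd now accepts both (N+1)-dim (unbatched) and (N+2)-dim (batched) inputs, so the formerly-invalid (N+1)-dim case had to be dropped from the list. For instance (sizes are illustrative):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3)

out = conv(torch.randn(3, 8, 8))     # 3D unbatched input: valid after this commit
print(out.shape)                     # torch.Size([8, 6, 6])
out = conv(torch.randn(2, 3, 8, 8))  # 4D batched input: valid as before

try:
    conv(torch.randn(3, 8))          # 2D input: still invalid for conv2d
except RuntimeError as e:
    print(e)  # Expected 3D (unbatched) or 4D (batched) input to conv2d, ...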
