Skip to content

Commit df58020

Browse files
deandyupytorchmergebot
authored andcommitted
Align max_pool1d Error Checking between CPU and CUDA/CPU requires_grad (pytorch#90211)
Fixes pytorch#85712 Standardizes error checking for max_pool1d between CPU and CPU requires_grad/CUDA. Pull Request resolved: pytorch#90211 Approved by: https://github.com/mruberry
1 parent 3859aac commit df58020

File tree

2 files changed

+46
-35
lines changed

2 files changed

+46
-35
lines changed

aten/src/ATen/native/MaxPooling.cpp

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,13 @@ DEFINE_DISPATCH(max_pool1d_stub);
2424

2525
namespace {
2626

27-
Tensor max_pool1d_impl(
27+
static void check_max_pool1d(
2828
const Tensor& self,
2929
IntArrayRef kernel_size,
3030
IntArrayRef stride,
3131
IntArrayRef padding,
3232
IntArrayRef dilation,
3333
bool ceil_mode) {
34-
NoNamesGuard guard;
3534

3635
TORCH_CHECK(
3736
self.dim() == 2 || self.dim() == 3,
@@ -58,33 +57,54 @@ Tensor max_pool1d_impl(
5857
stride = kernel_size;
5958
}
6059

61-
const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
62-
const int64_t NC = self.size(-2);
63-
const int64_t IW = self.size(-1);
64-
const int64_t KW = kernel_size[0];
65-
const int64_t SJ = stride[0];
66-
const int64_t PJ = padding[0];
67-
const int64_t DJ = dilation[0];
68-
6960
TORCH_CHECK(
70-
KW > 0,
61+
kernel_size[0] > 0,
7162
"max_pool1d() kernel_size must be greater than zero, but got ",
72-
KW);
63+
kernel_size[0]);
7364
TORCH_CHECK(
74-
SJ > 0, "max_pool1d() stride must be greater than zero, but got ", SJ);
65+
stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
7566
TORCH_CHECK(
76-
PJ >= 0, "max_pool1d() padding must be non-negative, but got ", PJ);
67+
padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
7768
TORCH_CHECK(
78-
PJ <= KW / 2,
69+
padding[0] <= kernel_size[0] / 2,
7970
"max_pool1d() padding should be at most half of kernel size, but got padding=",
80-
PJ,
71+
padding[0],
8172
" and kernel_size=",
82-
KW);
73+
kernel_size[0]);
8374
TORCH_CHECK(
84-
DJ > 0, "max_pool1d() dilation must be greater than zero, but got ", DJ);
75+
dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
8576

86-
const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
77+
const int64_t OW = pooling_output_shape(self.size(-1), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
8778
TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
79+
}
80+
81+
} // namespace
82+
83+
namespace {
84+
85+
Tensor max_pool1d_impl(
86+
const Tensor& self,
87+
IntArrayRef kernel_size,
88+
IntArrayRef stride,
89+
IntArrayRef padding,
90+
IntArrayRef dilation,
91+
bool ceil_mode) {
92+
NoNamesGuard guard;
93+
94+
// If stride=None then set it to kernel_size
95+
if (stride.empty()) {
96+
stride = kernel_size;
97+
}
98+
99+
const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
100+
const int64_t NC = self.size(-2);
101+
const int64_t IW = self.size(-1);
102+
const int64_t KW = kernel_size[0];
103+
const int64_t SJ = stride[0];
104+
const int64_t PJ = padding[0];
105+
const int64_t DJ = dilation[0];
106+
107+
const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
88108
Tensor output = at::empty({NB, NC, OW}, self.options());
89109

90110
PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
@@ -121,6 +141,8 @@ Tensor max_pool1d(
121141
return at::quantized_max_pool1d(
122142
self, kernel_size, stride, padding, dilation, ceil_mode);
123143
}
144+
145+
check_max_pool1d(self, kernel_size, stride, padding, dilation, ceil_mode);
124146
if ((self.requires_grad() && at::GradMode::is_enabled()) ||
125147
self._fw_grad(/*level */ 0).defined() ||
126148
!self.device().is_cpu() ||

torch/testing/_internal/common_methods_invocations.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3032,40 +3032,29 @@ def error_inputs_max_pool1d(op_info, device, **kwargs):
30323032
error_regex=error_msg)
30333033

30343034
# error inputs for empty input with stride=0
3035-
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
3036-
error_msg = 'stride must be greater than zero, but got 0' if torch.device(
3037-
device).type == 'cpu' and not requires_grad else 'stride should not be zero'
3035+
error_msg = 'stride must be greater than zero, but got 0'
30383036
yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
30393037
error_regex=error_msg)
30403038

30413039
# error inputs for empty input with dilation=0
3042-
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
3043-
error_msg = 'dilation must be greater than zero, but got 0' if torch.device(
3044-
device).type == 'cpu' and not requires_grad else 'dilation should be greater than zero, but got dilation'
3040+
error_msg = 'dilation must be greater than zero, but got 0'
30453041
yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
30463042
kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
30473043
error_regex=error_msg)
30483044

30493045
# error inputs for invalid output size
3050-
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
3051-
error_msg = 'Invalid computed output size: -2' if torch.device(device).type == 'cpu' and not requires_grad \
3052-
else \
3053-
r'Given input size: \(2x1x2\). Calculated output size: \(2x1x-2\). Output size is too small'
3046+
error_msg = 'Invalid computed output size: -2'
30543047
yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
30553048
kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
30563049
error_regex=error_msg)
30573050

30583051
# error inputs when kernel_size=0
3059-
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
3060-
error_msg = 'kernel_size must be greater than zero' if torch.device(
3061-
device).type == 'cpu' and not requires_grad else r'stride should not be zero'
3052+
error_msg = 'kernel_size must be greater than zero'
30623053
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
30633054
error_regex=error_msg)
30643055

30653056
# error inputs for strides > 0
3066-
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
3067-
error_msg = 'stride must be greater than zero' if torch.device(
3068-
device).type == 'cpu' and not requires_grad else r'stride should not be zero'
3057+
error_msg = 'stride must be greater than zero'
30693058
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
30703059
error_regex=error_msg)
30713060

0 commit comments

Comments
 (0)