@@ -24,14 +24,13 @@ DEFINE_DISPATCH(max_pool1d_stub);
 
 namespace {
 
-Tensor max_pool1d_impl(
+static void check_max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
-  NoNamesGuard guard;
 
  TORCH_CHECK(
      self.dim() == 2 || self.dim() == 3,
@@ -58,33 +57,54 @@ Tensor max_pool1d_impl(
    stride = kernel_size;
  }
 
-  const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
-  const int64_t NC = self.size(-2);
-  const int64_t IW = self.size(-1);
-  const int64_t KW = kernel_size[0];
-  const int64_t SJ = stride[0];
-  const int64_t PJ = padding[0];
-  const int64_t DJ = dilation[0];
-
  TORCH_CHECK(
-      KW > 0,
+      kernel_size[0] > 0,
      "max_pool1d() kernel_size must be greater than zero, but got ",
-      KW);
+      kernel_size[0]);
  TORCH_CHECK(
-      SJ > 0, "max_pool1d() stride must be greater than zero, but got ", SJ);
+      stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
  TORCH_CHECK(
-      PJ >= 0, "max_pool1d() padding must be non-negative, but got ", PJ);
+      padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
  TORCH_CHECK(
-      PJ <= KW / 2,
+      padding[0] <= kernel_size[0] / 2,
      "max_pool1d() padding should be at most half of kernel size, but got padding=",
-      PJ,
+      padding[0],
      " and kernel_size=",
-      KW);
+      kernel_size[0]);
  TORCH_CHECK(
-      DJ > 0, "max_pool1d() dilation must be greater than zero, but got ", DJ);
+      dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
 
-  const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
+  const int64_t OW = pooling_output_shape(self.size(-1), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
  TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
+}
+
+} // namespace
+
+namespace {
+
+Tensor max_pool1d_impl(
+    const Tensor& self,
+    IntArrayRef kernel_size,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    bool ceil_mode) {
+  NoNamesGuard guard;
+
+  // If stride=None then set it to kernel_size
+  if (stride.empty()) {
+    stride = kernel_size;
+  }
+
+  const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
+  const int64_t NC = self.size(-2);
+  const int64_t IW = self.size(-1);
+  const int64_t KW = kernel_size[0];
+  const int64_t SJ = stride[0];
+  const int64_t PJ = padding[0];
+  const int64_t DJ = dilation[0];
+
+  const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
  Tensor output = at::empty({NB, NC, OW}, self.options());
 
  PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
@@ -121,6 +141,8 @@ Tensor max_pool1d(
    return at::quantized_max_pool1d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }
+
+  check_max_pool1d(self, kernel_size, stride, padding, dilation, ceil_mode);
  if ((self.requires_grad() && at::GradMode::is_enabled()) ||
      self._fw_grad(/*level*/ 0).defined() ||
      !self.device().is_cpu() ||
0 commit comments