Skip to content

Commit 0bab1d8

Browse files
BruceDailina128
andauthored
Expose dimRoundingMode attribute of pool op (tensorflow#5849)
* Expose dimRoundingMode attribute of pool op. * Add conv_util.checkPadOnDimRoundingMode() function with tests. Co-authored-by: Na Li <[email protected]>
1 parent 2dcda6f commit 0bab1d8

29 files changed

+915
-135
lines changed

tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,8 @@ export const depthwiseConv2dNativeGradConfig: GradConfig = {
5858
`dilations must be 1. Got strides ${strides} and dilations ` +
5959
`'${$dilations}'.`);
6060

61-
if (dimRoundingMode != null) {
62-
util.assert(
63-
util.isInt(pad as number),
64-
() =>
65-
`Error in depthwiseConv2d: pad must be an integer when using, ` +
66-
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
67-
}
61+
conv_util.checkPadOnDimRoundingMode(
62+
'depthwiseConv2d', pad, dimRoundingMode);
6863

6964
return {
7065
x: () => depthwiseConv2dNativeBackpropInput(

tfjs-core/src/ops/avg_pool.ts

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,8 @@ function avgPool_<T extends Tensor3D|Tensor4D>(
7272
util.assert(
7373
x4D.rank === 4,
7474
() => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`);
75-
76-
if (dimRoundingMode != null) {
77-
util.assert(
78-
util.isInt(pad as number),
79-
() => `Error in avgPool: pad must be an integer when using, ` +
80-
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
81-
}
82-
75+
conv_util.checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
8376
const inputs: AvgPoolInputs = {x: x4D};
84-
8577
const attrs: AvgPoolAttrs = {filterSize, strides, pad, dimRoundingMode};
8678

8779
// tslint:disable-next-line: no-unnecessary-type-assertion

tfjs-core/src/ops/avg_pool_3d.ts

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env';
2424
import {TensorLike} from '../types';
2525
import * as util from '../util';
2626

27+
import {checkPadOnDimRoundingMode} from './conv_util';
2728
import {cast} from './cast';
2829
import {op} from './operation';
2930
import {reshape} from './reshape';
@@ -85,16 +86,8 @@ function avgPool3d_<T extends Tensor4D|Tensor5D>(
8586
dataFormat === 'NDHWC',
8687
() => `Error in avgPool3d: Only NDHWC is currently supported, ` +
8788
`but got dataFormat of ${dataFormat}`);
88-
89-
if (dimRoundingMode != null) {
90-
util.assert(
91-
util.isInt(pad as number),
92-
() => `Error in avgPool3d: pad must be an integer when using, ` +
93-
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
94-
}
95-
89+
checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
9690
const inputs: AvgPool3DInputs = {x: x5D};
97-
9891
const attrs:
9992
AvgPool3DAttrs = {filterSize, strides, pad, dimRoundingMode, dataFormat};
10093

tfjs-core/src/ops/avg_pool_3d_grad.ts

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {convertToTensor} from '../tensor_util_env';
2525
import {TensorLike} from '../types';
2626
import * as util from '../util';
2727

28+
import {checkPadOnDimRoundingMode} from './conv_util';
2829
import {op} from './operation';
2930
import {reshape} from './reshape';
3031

@@ -77,16 +78,8 @@ function avgPool3dGrad_<T extends Tensor4D|Tensor5D>(
7778
input5D.rank === 5,
7879
() => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
7980
`${input5D.rank}.`);
80-
81-
if (dimRoundingMode != null) {
82-
util.assert(
83-
util.isInt(pad as number),
84-
() => `Error in avgPool3dGrad: pad must be an integer when ` +
85-
`using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
86-
}
87-
81+
checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
8882
const inputs: AvgPool3DGradInputs = {dy: dy5D, input: input5D};
89-
9083
const attrs: AvgPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode};
9184

9285
// tslint:disable-next-line: no-unnecessary-type-assertion

tfjs-core/src/ops/avg_pool_3d_test.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,31 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => {
157157
expect(() => tf.avgPool3d(x as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
158158
});
159159

160-
it('throws when dimRoundingMode is set and pad is not a number', async () => {
160+
it('throws when dimRoundingMode is set and pad is same', async () => {
161+
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
162+
const pad = 'same';
163+
const dimRoundingMode = 'round';
164+
165+
expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
166+
});
167+
168+
it('throws when dimRoundingMode is set and pad is valid', async () => {
161169
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
162170
const pad = 'valid';
163171
const dimRoundingMode = 'round';
164172

165173
expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
166174
});
167175

176+
it('throws when dimRoundingMode is set and pad is a non-integer number',
177+
async () => {
178+
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
179+
const pad = 1.2;
180+
const dimRoundingMode = 'round';
181+
182+
expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
183+
});
184+
168185
it('throws when passed a non-tensor', () => {
169186
expect(() => tf.avgPool3d({} as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
170187
});

tfjs-core/src/ops/avg_pool_test.ts

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
104104
expectArraysClose(await result.data(), [2.5, 3, 4, 4.5, 5.5, 6, 7, 7.5]);
105105
});
106106

107+
it('x=[2,2,3] f=[2,2] s=3 p=1 default dimRoundingMode', () => {
108+
// Feed forward.
109+
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
110+
const result = tf.avgPool(x, 2, 3, 1);
111+
112+
expect(result.shape).toEqual([1, 1, 3]);
113+
});
114+
107115
it('x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor', () => {
108116
// Feed forward.
109117
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
@@ -112,6 +120,30 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
112120
expect(result.shape).toEqual([2, 2, 3]);
113121
});
114122

123+
it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=floor', () => {
124+
// Feed forward.
125+
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
126+
const result = tf.avgPool(x, 2, 3, 1, 'floor');
127+
128+
expect(result.shape).toEqual([1, 1, 3]);
129+
});
130+
131+
it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=round', () => {
132+
// Feed forward.
133+
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
134+
const result = tf.avgPool(x, 2, 3, 1, 'round');
135+
136+
expect(result.shape).toEqual([2, 2, 3]);
137+
});
138+
139+
it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=ceil', () => {
140+
// Feed forward.
141+
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
142+
const result = tf.avgPool(x, 2, 3, 1, 'ceil');
143+
144+
expect(result.shape).toEqual([2, 2, 3]);
145+
});
146+
115147
it('gradient x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => {
116148
const x = tf.tensor3d([0], [1, 1, 1]);
117149
const dy = tf.tensor3d([0], [1, 1, 1]);
@@ -178,7 +210,16 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
178210
]);
179211
});
180212

181-
it('throws when dimRoundingMode is set and pad is not a number', () => {
213+
it('throws when dimRoundingMode is set and pad is same', () => {
214+
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
215+
216+
const pad = 'same';
217+
const dimRoundingMode = 'round';
218+
219+
expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
220+
});
221+
222+
it('throws when dimRoundingMode is set and pad is valid', () => {
182223
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
183224

184225
const pad = 'valid';
@@ -187,6 +228,28 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
187228
expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
188229
});
189230

231+
it('throws when dimRoundingMode is set and pad is a non-integer number',
232+
() => {
233+
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
234+
235+
const pad = 1.2;
236+
const dimRoundingMode = 'round';
237+
238+
expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
239+
});
240+
241+
it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
242+
'number',
243+
() => {
244+
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
245+
246+
const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
247+
tf.backend_util.ExplicitPadding;
248+
const dimRoundingMode = 'round';
249+
250+
expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
251+
});
252+
190253
it('throws when passed a non-tensor', () => {
191254
expect(() => tf.avgPool({} as tf.Tensor3D, 2, 1, 'valid'))
192255
.toThrowError(/Argument 'x' passed to 'avgPool' must be a Tensor/);

tfjs-core/src/ops/conv1d.ts

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,7 @@ function conv1d_<T extends Tensor2D|Tensor3D>(
7474
$filter.rank === 3,
7575
() => `Error in conv1d: filter must be rank 3, but got rank ` +
7676
`${$filter.rank}.`);
77-
if (dimRoundingMode != null) {
78-
util.assert(
79-
util.isInt(pad as number),
80-
() => `Error in conv1d: pad must be an integer when using, ` +
81-
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
82-
}
83-
77+
conv_util.checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
8478
util.assert(
8579
x3D.shape[2] === $filter.shape[1],
8680
() => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +

tfjs-core/src/ops/conv1d_test.ts

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,90 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
134134
expectArraysClose(await result.data(), await expectedResult.data());
135135
});
136136

137+
it('throws when dimRoundingMode is set and pad is same', () => {
138+
const inputDepth = 1;
139+
const inputShape: [number, number, number] = [2, 2, inputDepth];
140+
const outputDepth = 1;
141+
const fSize = 1;
142+
const pad = 'same';
143+
const stride = 1;
144+
const dataFormat = 'NWC';
145+
const dilation = 1;
146+
const dimRoundingMode = 'round';
147+
148+
const x = tf.tensor3d([1, 2, 3, 4], inputShape);
149+
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);
150+
151+
expect(
152+
() => tf.conv1d(
153+
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
154+
.toThrowError();
155+
});
156+
157+
it('throws when dimRoundingMode is set and pad is valid', () => {
158+
const inputDepth = 1;
159+
const inputShape: [number, number, number] = [2, 2, inputDepth];
160+
const outputDepth = 1;
161+
const fSize = 1;
162+
const pad = 'valid';
163+
const stride = 1;
164+
const dataFormat = 'NWC';
165+
const dilation = 1;
166+
const dimRoundingMode = 'round';
167+
168+
const x = tf.tensor3d([1, 2, 3, 4], inputShape);
169+
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);
170+
171+
expect(
172+
() => tf.conv1d(
173+
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
174+
.toThrowError();
175+
});
176+
177+
it('throws when dimRoundingMode is set and pad is a non-integer number',
178+
() => {
179+
const inputDepth = 1;
180+
const inputShape: [number, number, number] = [2, 2, inputDepth];
181+
const outputDepth = 1;
182+
const fSize = 1;
183+
const pad = 1.2;
184+
const stride = 1;
185+
const dataFormat = 'NWC';
186+
const dilation = 1;
187+
const dimRoundingMode = 'round';
188+
189+
const x = tf.tensor3d([1, 2, 3, 4], inputShape);
190+
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);
191+
192+
expect(
193+
() => tf.conv1d(
194+
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
195+
.toThrowError();
196+
});
197+
198+
it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
199+
'number',
200+
() => {
201+
const inputDepth = 1;
202+
const inputShape: [number, number, number] = [2, 2, inputDepth];
203+
const outputDepth = 1;
204+
const fSize = 1;
205+
const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
206+
tf.backend_util.ExplicitPadding;
207+
const stride = 1;
208+
const dataFormat = 'NWC';
209+
const dilation = 1;
210+
const dimRoundingMode = 'round';
211+
212+
const x = tf.tensor3d([1, 2, 3, 4], inputShape);
213+
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);
214+
215+
expect(
216+
() => tf.conv1d(
217+
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
218+
.toThrowError();
219+
});
220+
137221
it('TensorLike', async () => {
138222
const pad = 'same';
139223
const stride = 1;

tfjs-core/src/ops/conv2d.ts

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,7 @@ function conv2d_<T extends Tensor3D|Tensor4D>(
8484
$filter.rank === 4,
8585
() => `Error in conv2d: filter must be rank 4, but got rank ` +
8686
`${$filter.rank}.`);
87-
if (dimRoundingMode != null) {
88-
util.assert(
89-
util.isInt(pad as number),
90-
() => `Error in conv2d: pad must be an integer when using, ` +
91-
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
92-
}
93-
87+
conv_util.checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
9488
const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
9589
util.assert(
9690
inDepth === $filter.shape[2],

tfjs-core/src/ops/conv2d_backprop_filter.ts

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,7 @@ function conv2DBackpropFilter_<T extends Tensor3D|Tensor4D>(
8181
outDepth === filterShape[3],
8282
() => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
8383
`match output depth for filter (${filterShape[3]}).`);
84-
if (dimRoundingMode != null) {
85-
util.assert(
86-
util.isInt(pad as number),
87-
() => `Error in conv2dDerFilter: pad must be an integer when using, ` +
88-
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
89-
}
90-
84+
conv_util.checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
9185
const inputs: Conv2DBackpropFilterInputs = {x: x4D, dy: dy4D};
9286
const attrs: Conv2DBackpropFilterAttrs =
9387
{strides, pad, dataFormat, dimRoundingMode, filterShape};

tfjs-core/src/ops/conv2d_backprop_input.ts

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,8 @@ function conv2DBackpropInput_<T extends Tensor3D|Tensor4D>(
9292
outDepth === filter.shape[3],
9393
() => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
9494
`match output depth for filter ${filter.shape[3]}.`);
95-
if (dimRoundingMode != null) {
96-
util.assert(
97-
util.isInt(pad as number),
98-
() => `Error in conv2dDerInput: pad must be an integer when using, ` +
99-
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
100-
}
101-
95+
conv_util.checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
10296
const inputs: Conv2DBackpropInputInputs = {dy: dy4D, filter};
103-
10497
const attrs: Conv2DBackpropInputAttrs =
10598
{strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D};
10699

0 commit comments

Comments
 (0)