
Commit 9c742be

salilsdesai authored and facebook-github-bot committed
[PyTorch Edge][QNNPack] Enable Depthwise Specific Conv3d Kernel for Kernel Size 3x3x3 (pytorch#69315)
Summary: Pull Request resolved: pytorch#69315

Uses kernels and setup modifications from earlier diffs in this stack

ghstack-source-id: 146346780

Test Plan:

**Correctness**

- Test using QNNPack Operator-Level Test:
  - Neon Kernel: As in the test plan of D32217846, all tests pass
  - SSE2 Kernel: `buck test xplat/caffe2/aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack_test`, all tests pass
- Test by printing results of Model-Level Test: D32122020

**Performance**

*Operator-level tests from convolution.cc in D32217846*

| | Before (V23 of D32217846, without newly added kernel) | After (V48 of D31966574, with newly added kernel) |
|---|---|---|
| depthwise 3x3x3 static | 184 ms | 134 ms |
| depthwise 3x3x3 runtime | 181 ms | 134 ms |
| depthwise 3x3x3s2 static | 30 ms | 22 ms |
| depthwise 3x3x3s2 runtime | 30 ms | 23 ms |
| depthwise 3x3x3s1x2 static | 97 ms | 70 ms |
| depthwise 3x3x3s1x2 runtime | 96 ms | 70 ms |
| depthwise 3x3x3s2x1 static | 53 ms | 38 ms |
| depthwise 3x3x3s2x1 runtime | 53 ms | 38 ms |
| depthwise 3x3x3d2 static | 104 ms | 74 ms |
| depthwise 3x3x3d2 runtime | 103 ms | 75 ms |
| depthwise 3x3x3d1x2 static | 158 ms | 116 ms |
| depthwise 3x3x3d1x2 runtime | 157 ms | 115 ms |
| depthwise 3x3x3d2x1 static | 120 ms | 86 ms |
| depthwise 3x3x3d2x1 runtime | 120 ms | 87 ms |
| depthwise 3x3x3 per channel static | 182 ms | 134 ms |
| depthwise 3x3x3 per channel runtime | 184 ms | 134 ms |
| depthwise 3x3x3s2 per channel static | 30 ms | 22 ms |
| depthwise 3x3x3s2 per channel runtime | 31 ms | 23 ms |
| depthwise 3x3x3s1x2 per channel static | 95 ms | 70 ms |
| depthwise 3x3x3s1x2 per channel runtime | 95 ms | 71 ms |
| depthwise 3x3x3s2x1 per channel static | 53 ms | 39 ms |
| depthwise 3x3x3s2x1 per channel runtime | 55 ms | 39 ms |
| depthwise 3x3x3d2 per channel static | 105 ms | 75 ms |
| depthwise 3x3x3d2 per channel runtime | 103 ms | 75 ms |
| depthwise 3x3x3d1x2 per channel static | 158 ms | 116 ms |
| depthwise 3x3x3d1x2 per channel runtime | 158 ms | 116 ms |
| depthwise 3x3x3d2x1 per channel static | 118 ms | 87 ms |
| depthwise 3x3x3d2x1 per channel runtime | 119 ms | 87 ms |

Average change: -36.96% (generated with https://www.internalfb.com/intern/anp/view/?id=1371846&revision_id=291376782898627)

*Model-level test on synthesized Conv3d model*

Model details:
- 21 channels, input size: 9 x 12 x 7, kernel size: 3x3x3
- Config added in D31928710
- Model generated with https://www.internalfb.com/intern/anp/view/?id=1313660&revision_id=248658657303993

`buck run aibench:run_bench -- -b dw_conv_3d_3x3x3_big_2b.json --platform android/arm64 --framework pytorch --remote --devices Pixel-4a-11-30`

- Before (V23 of D32217846): [0.0935 ms](https://our.intern.facebook.com/intern/aibench/details/768298420366437)
- After (V48 of D31966574): [0.0665 ms](https://our.intern.facebook.com/intern/aibench/details/67271954298132) (29% faster)

*Model-level test on video-model-like inputs (provided by liyilui)*
- D33000199
- 87.5% faster

Reviewed By: kimishpatel

Differential Revision: D31966574

fbshipit-source-id: 6554a878401c1120054f6b02241456e8fb44b152
1 parent 3d4590d commit 9c742be
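
For reference, a minimal libtorch sketch of a depthwise Conv3d with the shape used in the model-level test described above (21 channels, 3x3x3 kernel, 9 x 12 x 7 input). This is only an illustrative fp32 stand-in; the benchmark itself runs a quantized model built from the config in D31928710, and the NCDHW ordering of the input dimensions here is an assumption.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Depthwise 3D convolution: groups == in_channels == out_channels == 21,
  // 3x3x3 kernel, mirroring the synthesized benchmark model described above.
  auto conv = torch::nn::Conv3d(
      torch::nn::Conv3dOptions(/*in_channels=*/21, /*out_channels=*/21,
                               /*kernel_size=*/3)
          .groups(21));

  // One input of size 9 x 12 x 7 with 21 channels (NCDHW order assumed).
  auto input = torch::randn({1, 21, 9, 12, 7});
  auto output = conv->forward(input);

  std::cout << output.sizes() << std::endl;  // [1, 21, 7, 10, 5]
  return 0;
}
```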

File tree

9 files changed: +546 -135 lines changed


aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-prepack.cc

Lines changed: 56 additions & 0 deletions
@@ -96,6 +96,62 @@ PrePackConvWeights::PrePackConvWeights(
           (20 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride,
           false);
       break;
+    case 27:
+      pytorch_pack_q8dw_3d_w_dilation(
+          kernel_depth,
+          kernel_height,
+          kernel_width,
+          groups,
+          cr,
+          0,
+          kernel_depth,
+          0,
+          kernel_height,
+          0,
+          1,
+          kernel,
+          bias,
+          packed_weights_,
+          true);
+      pytorch_pack_q8dw_3d_w_dilation(
+          kernel_depth,
+          kernel_height,
+          kernel_width,
+          groups,
+          cr,
+          0,
+          kernel_depth,
+          0,
+          kernel_height,
+          1,
+          2,
+          kernel,
+          bias,
+          (char*)packed_weights_ +
+              (kernel_depth * kernel_height +
+               sizeof(int32_t) / sizeof(uint8_t)) *
+                  c_stride,
+          false);
+      pytorch_pack_q8dw_3d_w_dilation(
+          kernel_depth,
+          kernel_height,
+          kernel_width,
+          groups,
+          cr,
+          0,
+          kernel_depth,
+          0,
+          kernel_height,
+          2,
+          3,
+          kernel,
+          bias,
+          (char*)packed_weights_ +
+              (2 * kernel_depth * kernel_height +
+               sizeof(int32_t) / sizeof(uint8_t)) *
+                  c_stride,
+          false);
+      break;
     default:
       PYTORCH_QNNP_UNREACHABLE;
   }
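
The case-27 branch above packs the 3x3x3 depthwise weights in three passes, with the final pair of index arguments stepping 0-1, 1-2, 2-3 and each later pass writing at a growing byte offset into the packed buffer. Below is a standalone sketch of just that offset arithmetic; the c_stride value is an assumed example, and the reading of the last boolean argument as "also pack the bias" is inferred from the call pattern, not confirmed by this diff.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Standalone sketch of the packed-weights offsets used by the three
// pytorch_pack_q8dw_3d_w_dilation calls in the case-27 branch above.
// Names mirror conv-prepack.cc; c_stride is an assumed example value.
int main(void) {
  const size_t kernel_depth = 3;
  const size_t kernel_height = 3;
  const size_t c_stride = 8;  // example rounded-up channel count (assumed)

  for (size_t pass = 0; pass < 3; pass++) {
    // Pass 0 starts at the beginning of each per-channel block (and appears
    // to also pack the int32 bias there); later passes skip the bias plus the
    // kernel_depth * kernel_height uint8 weights written by each earlier pass.
    const size_t offset = (pass == 0)
        ? 0
        : (pass * kernel_depth * kernel_height +
           sizeof(int32_t) / sizeof(uint8_t)) * c_stride;
    std::printf("pass %zu writes starting at byte offset %zu\n", pass, offset);
  }
  return 0;
}
```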

aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc

Lines changed: 92 additions & 11 deletions
@@ -205,7 +205,7 @@ static void compute_sum_rows(
       block_start);
 }
 
-struct q8dwconv_context {
+struct q8dwconv2d_context {
   size_t groups;
   size_t group_stride;
   const uint8_t** indirection_buffer;
@@ -218,11 +218,29 @@ struct q8dwconv_context {
   size_t output_row_stride;
   size_t output_col_increment;
   union pytorch_qnnp_conv_quantization_params quantization_params;
-  const pytorch_q8dwconv_up_ukernel_function unipass_ukernel;
-  const pytorch_q8dwconv_mp_ukernel_function multipass_ukernel;
+  const pytorch_q8dwconv2d_up_ukernel_function unipass_ukernel;
+  const pytorch_q8dwconv2d_mp_ukernel_function multipass_ukernel;
 };
-static void compute_dwconv_unipass(
-    const struct q8dwconv_context context[1],
+
+struct q8dwconv3d_context {
+  size_t groups;
+  size_t group_stride;
+  const uint8_t** indirection_buffer;
+  size_t indirection_buffer_slice_stride;
+  size_t indirection_buffer_row_stride;
+  size_t indirection_buffer_col_stride;
+  const void* packed_weights;
+  uint8_t* output;
+  size_t output_depth;
+  size_t output_height;
+  size_t output_width;
+  size_t output_slice_stride;
+  union pytorch_qnnp_conv_quantization_params quantization_params;
+  const pytorch_q8dwconv3d_mp_ukernel_function multipass_ukernel;
+};
+
+static void compute_dwconv2d_unipass(
+    const struct q8dwconv2d_context context[1],
     size_t image,
     size_t output_y) {
   const size_t output_height = context->output_height;
@@ -240,8 +258,8 @@ static void compute_dwconv_unipass(
       context->output_col_increment,
       &context->quantization_params);
 }
-static void compute_dwconv_multiipass(
-    const struct q8dwconv_context context[1],
+static void compute_dwconv2d_multiipass(
+    const struct q8dwconv2d_context context[1],
     size_t image,
     size_t output_y) {
   const size_t output_height = context->output_height;
@@ -271,6 +289,40 @@ static void compute_dwconv_multiipass(
 #endif
 }
 
+static void compute_dwconv3d_multiipass(
+    const struct q8dwconv3d_context context[1],
+    size_t image,
+    size_t output_z) {
+  const size_t output_depth = context->output_depth;
+  PYTORCH_QNNP_ALIGN(16)
+#ifdef _MSC_VER
+  int32_t* multipass_acc =
+      (int32_t*)_malloca(sizeof(int32_t) * context->group_stride);
+#else
+  int32_t multipass_acc[context->group_stride];
+#endif
+
+  context->multipass_ukernel(
+      context->groups,
+      context->output_height,
+      context->output_width,
+      context->indirection_buffer +
+          (image * output_depth + output_z) *
+              context->indirection_buffer_slice_stride,
+      context->packed_weights,
+      multipass_acc,
+      context->output +
+          (image * output_depth + output_z) * context->output_slice_stride,
+      context->indirection_buffer_row_stride,
+      context->indirection_buffer_col_stride,
+      0,
+      &context->quantization_params);
+
+#ifdef _MSC_VER
+  _freea(multipass_acc);
+#endif
+}
+
 struct QnnpackDeleter {
   void operator()(pytorch_qnnp_operator_t op) {
     pytorch_qnnp_delete_operator(op);
@@ -366,7 +418,7 @@ enum pytorch_qnnp_status qnnpackConv(
 
     switch (kernel_size) {
       case 9: {
-        struct q8dwconv_context context = {
+        struct q8dwconv2d_context context = {
             .groups = groups,
             .group_stride = group_stride,
             .indirection_buffer =
@@ -392,14 +444,14 @@ enum pytorch_qnnp_status qnnpackConv(
         };
         pthreadpool_compute_2d(
             threadpool,
-            (pthreadpool_function_2d_t)compute_dwconv_unipass,
+            (pthreadpool_function_2d_t)compute_dwconv2d_unipass,
             &context,
             batch_size,
             convolution->output_height);
         break;
       }
       case 25: {
-        struct q8dwconv_context context = {
+        struct q8dwconv2d_context context = {
             .groups = groups,
             .group_stride = group_stride,
             .indirection_buffer =
@@ -425,12 +477,41 @@ enum pytorch_qnnp_status qnnpackConv(
         };
         pthreadpool_compute_2d(
             threadpool,
-            (pthreadpool_function_2d_t)compute_dwconv_multiipass,
+            (pthreadpool_function_2d_t)compute_dwconv2d_multiipass,
             &context,
             batch_size,
             convolution->output_height);
         break;
       }
+      case 27: {
+        struct q8dwconv3d_context context = {
+            .groups = groups,
+            .group_stride = group_stride,
+            .indirection_buffer =
+                (const uint8_t**)convolution->indirection_buffer,
+            .indirection_buffer_slice_stride =
+                step_height * convolution->output_height,
+            .indirection_buffer_row_stride = step_height * sizeof(void*),
+            .indirection_buffer_col_stride =
+                kernel_height * kernel_depth * step_width * sizeof(void*),
+            .packed_weights = packed_weights,
+            .output = output,
+            .output_depth = convolution->output_depth,
+            .output_height = convolution->output_height,
+            .output_width = convolution->output_width,
+            .output_slice_stride = convolution->output_height *
+                convolution->output_width * output_pixel_stride,
+            .quantization_params = conv_quantization_params,
+            .multipass_ukernel = pytorch_qnnp_params.q8dw27.mpdw,
+        };
+        pthreadpool_compute_2d(
+            threadpool,
+            (pthreadpool_function_2d_t)compute_dwconv3d_multiipass,
+            &context,
+            batch_size,
+            convolution->output_depth);
+        break;
+      }
       default:
         PYTORCH_QNNP_UNREACHABLE;
     }
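
The new case-27 dispatch hands compute_dwconv3d_multiipass to pthreadpool_compute_2d over batch images and output depth slices. Below is a standalone sketch of the equivalent serial iteration, with a stub standing in for the ukernel invocation and assumed example sizes.

```cpp
#include <cstddef>
#include <cstdio>

// Stand-in for compute_dwconv3d_multiipass: in the real code each call runs
// the 3D depthwise multipass ukernel for one (image, output depth slice) pair.
static void dwconv3d_slice_stub(size_t image, size_t output_z) {
  std::printf("image %zu, output slice %zu\n", image, output_z);
}

int main(void) {
  const size_t batch_size = 2;    // assumed example value
  const size_t output_depth = 7;  // assumed example value

  // pthreadpool_compute_2d(threadpool, fn, &context, batch_size, output_depth)
  // performs this doubly nested iteration, distributed across worker threads.
  for (size_t image = 0; image < batch_size; image++) {
    for (size_t output_z = 0; output_z < output_depth; output_z++) {
      dwconv3d_slice_stub(image, output_z);
    }
  }
  return 0;
}
```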
