@@ -205,7 +205,7 @@ static void compute_sum_rows(
205
205
block_start);
206
206
}
207
207
208
- struct q8dwconv_context {
208
+ struct q8dwconv2d_context {
209
209
size_t groups;
210
210
size_t group_stride;
211
211
const uint8_t ** indirection_buffer;
@@ -218,11 +218,29 @@ struct q8dwconv_context {
218
218
size_t output_row_stride;
219
219
size_t output_col_increment;
220
220
union pytorch_qnnp_conv_quantization_params quantization_params;
221
- const pytorch_q8dwconv_up_ukernel_function unipass_ukernel;
222
- const pytorch_q8dwconv_mp_ukernel_function multipass_ukernel;
221
+ const pytorch_q8dwconv2d_up_ukernel_function unipass_ukernel;
222
+ const pytorch_q8dwconv2d_mp_ukernel_function multipass_ukernel;
223
223
};
224
- static void compute_dwconv_unipass (
225
- const struct q8dwconv_context context[1 ],
224
+
225
+ struct q8dwconv3d_context {
226
+ size_t groups;
227
+ size_t group_stride;
228
+ const uint8_t ** indirection_buffer;
229
+ size_t indirection_buffer_slice_stride;
230
+ size_t indirection_buffer_row_stride;
231
+ size_t indirection_buffer_col_stride;
232
+ const void * packed_weights;
233
+ uint8_t * output;
234
+ size_t output_depth;
235
+ size_t output_height;
236
+ size_t output_width;
237
+ size_t output_slice_stride;
238
+ union pytorch_qnnp_conv_quantization_params quantization_params;
239
+ const pytorch_q8dwconv3d_mp_ukernel_function multipass_ukernel;
240
+ };
241
+
242
+ static void compute_dwconv2d_unipass (
243
+ const struct q8dwconv2d_context context[1 ],
226
244
size_t image,
227
245
size_t output_y) {
228
246
const size_t output_height = context->output_height ;
@@ -240,8 +258,8 @@ static void compute_dwconv_unipass(
240
258
context->output_col_increment ,
241
259
&context->quantization_params );
242
260
}
243
- static void compute_dwconv_multiipass (
244
- const struct q8dwconv_context context[1 ],
261
+ static void compute_dwconv2d_multiipass (
262
+ const struct q8dwconv2d_context context[1 ],
245
263
size_t image,
246
264
size_t output_y) {
247
265
const size_t output_height = context->output_height ;
@@ -271,6 +289,40 @@ static void compute_dwconv_multiipass(
271
289
#endif
272
290
}
273
291
292
+ static void compute_dwconv3d_multiipass (
293
+ const struct q8dwconv3d_context context[1 ],
294
+ size_t image,
295
+ size_t output_z) {
296
+ const size_t output_depth = context->output_depth ;
297
+ PYTORCH_QNNP_ALIGN (16 )
298
+ #ifdef _MSC_VER
299
+ int32_t * multipass_acc =
300
+ (int32_t *)_malloca (sizeof (int32_t ) * context->group_stride );
301
+ #else
302
+ int32_t multipass_acc[context->group_stride ];
303
+ #endif
304
+
305
+ context->multipass_ukernel (
306
+ context->groups ,
307
+ context->output_height ,
308
+ context->output_width ,
309
+ context->indirection_buffer +
310
+ (image * output_depth + output_z) *
311
+ context->indirection_buffer_slice_stride ,
312
+ context->packed_weights ,
313
+ multipass_acc,
314
+ context->output +
315
+ (image * output_depth + output_z) * context->output_slice_stride ,
316
+ context->indirection_buffer_row_stride ,
317
+ context->indirection_buffer_col_stride ,
318
+ 0 ,
319
+ &context->quantization_params );
320
+
321
+ #ifdef _MSC_VER
322
+ _freea (multipass_acc);
323
+ #endif
324
+ }
325
+
274
326
struct QnnpackDeleter {
275
327
void operator ()(pytorch_qnnp_operator_t op) {
276
328
pytorch_qnnp_delete_operator (op);
@@ -366,7 +418,7 @@ enum pytorch_qnnp_status qnnpackConv(
366
418
367
419
switch (kernel_size) {
368
420
case 9 : {
369
- struct q8dwconv_context context = {
421
+ struct q8dwconv2d_context context = {
370
422
.groups = groups,
371
423
.group_stride = group_stride,
372
424
.indirection_buffer =
@@ -392,14 +444,14 @@ enum pytorch_qnnp_status qnnpackConv(
392
444
};
393
445
pthreadpool_compute_2d (
394
446
threadpool,
395
- (pthreadpool_function_2d_t )compute_dwconv_unipass ,
447
+ (pthreadpool_function_2d_t )compute_dwconv2d_unipass ,
396
448
&context,
397
449
batch_size,
398
450
convolution->output_height );
399
451
break ;
400
452
}
401
453
case 25 : {
402
- struct q8dwconv_context context = {
454
+ struct q8dwconv2d_context context = {
403
455
.groups = groups,
404
456
.group_stride = group_stride,
405
457
.indirection_buffer =
@@ -425,12 +477,41 @@ enum pytorch_qnnp_status qnnpackConv(
425
477
};
426
478
pthreadpool_compute_2d (
427
479
threadpool,
428
- (pthreadpool_function_2d_t )compute_dwconv_multiipass ,
480
+ (pthreadpool_function_2d_t )compute_dwconv2d_multiipass ,
429
481
&context,
430
482
batch_size,
431
483
convolution->output_height );
432
484
break ;
433
485
}
486
+ case 27 : {
487
+ struct q8dwconv3d_context context = {
488
+ .groups = groups,
489
+ .group_stride = group_stride,
490
+ .indirection_buffer =
491
+ (const uint8_t **)convolution->indirection_buffer ,
492
+ .indirection_buffer_slice_stride =
493
+ step_height * convolution->output_height ,
494
+ .indirection_buffer_row_stride = step_height * sizeof (void *),
495
+ .indirection_buffer_col_stride =
496
+ kernel_height * kernel_depth * step_width * sizeof (void *),
497
+ .packed_weights = packed_weights,
498
+ .output = output,
499
+ .output_depth = convolution->output_depth ,
500
+ .output_height = convolution->output_height ,
501
+ .output_width = convolution->output_width ,
502
+ .output_slice_stride = convolution->output_height *
503
+ convolution->output_width * output_pixel_stride,
504
+ .quantization_params = conv_quantization_params,
505
+ .multipass_ukernel = pytorch_qnnp_params.q8dw27 .mpdw ,
506
+ };
507
+ pthreadpool_compute_2d (
508
+ threadpool,
509
+ (pthreadpool_function_2d_t )compute_dwconv3d_multiipass,
510
+ &context,
511
+ batch_size,
512
+ convolution->output_depth );
513
+ break ;
514
+ }
434
515
default :
435
516
PYTORCH_QNNP_UNREACHABLE;
436
517
}
0 commit comments