@@ -80,6 +80,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
8080 reg64_t reg_ow_pos = rdx;
8181 reg64_t aux_reg_output = reg_ow_pos;
8282 reg64_t reg_dg_iter = reg_output;
83+ reg64_t reg_gr_iter = rsp;
8384 reg64_t aux_reg_input = rax;
8485 reg64_t aux2_reg_input = reg_kernel;
8586 reg64_t reg_ic_iter = rbx;
@@ -853,7 +854,6 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
853854 L (oc_unrolled_loop); {
854855 cmp (reg_oc_work, jcp_.nb_oc_blocking * jcp_.oc_block );
855856 jl (oc_main_loop, T_NEAR);
856-
857857 ic_loop (ow_step, jcp_.nb_oc_blocking , jcp_.oc_block );
858858 store_output (ow_step, jcp_.nb_oc_blocking , jcp_.oc_block );
859859
@@ -869,7 +869,6 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_
869869 L (oc_main_loop); {
870870 cmp (reg_oc_work, jcp_.oc_block );
871871 jl (oc_tail, T_NEAR);
872-
873872 ic_loop (ow_step, 1 , jcp_.oc_block );
874873 store_output (ow_step, 1 , jcp_.oc_block );
875874
@@ -987,17 +986,18 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
987986 config.outConfs [0 ].inPlace = -1 ;
988987
989988 impl_desc_type impl_type;
990- if (mayiuse (cpu::x64::avx512_common)) {
991- impl_type = impl_desc_type::jit_avx512;
992- } else if (mayiuse (cpu::x64::avx2)) {
993- impl_type = impl_desc_type::jit_avx2;
994- } else if (mayiuse (cpu::x64::sse41)) {
995- impl_type = impl_desc_type::jit_sse42;
996- } else {
997- impl_type = impl_desc_type::ref;
998- }
999-
1000- if (mayiuse (cpu::x64::sse41)) {
989+ // if (mayiuse(cpu::x64::avx512_common)) {
990+ // impl_type = impl_desc_type::jit_avx512;
991+ // } else if (mayiuse(cpu::x64::avx2)) {
992+ // impl_type = impl_desc_type::jit_avx2;
993+ // } else if (mayiuse(cpu::x64::sse41)) {
994+ // impl_type = impl_desc_type::jit_sse42;
995+ // } else {
996+ // impl_type = impl_desc_type::ref;
997+ // }
998+ impl_type = impl_desc_type::ref;
999+
1000+ if (false && mayiuse (cpu::x64::sse41)) {
10011001 // optimzed implementation
10021002 auto dataFormat = memory::format_tag::nhwc;
10031003 auto offFormat = memory::format_tag::nchw;
@@ -1062,9 +1062,9 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {
10621062 jcp.oh = dstDims[2 ];
10631063 jcp.ow = dstDims[3 ];
10641064
1065- bool with_groups = group > 1 ;
1066- jcp.kh = weiDims[with_groups + 2 ];
1067- jcp.kw = weiDims[with_groups + 3 ];
1065+ // bool with_groups = group > 1;
1066+ jcp.kh = weiDims[2 ];
1067+ jcp.kw = weiDims[3 ];
10681068
10691069 jcp.t_pad = paddingL[0 ];
10701070 jcp.l_pad = paddingL[1 ];
@@ -1097,13 +1097,13 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {
10971097
10981098 jcp.nthr = dnnl_get_max_threads ();
10991099
1100- if (mayiuse (cpu::x64::avx512_common)) {
1101- def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1102- } else if (mayiuse (cpu::x64::avx2)) {
1103- def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1104- } else if (mayiuse (cpu::x64::sse41)) {
1105- def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1106- }
1100+ // if (mayiuse(cpu::x64::avx512_common)) {
1101+ // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1102+ // } else if (mayiuse(cpu::x64::avx2)) {
1103+ // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1104+ // } else if (mayiuse(cpu::x64::sse41)) {
1105+ // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1106+ // }
11071107
11081108 if (def_conv_kernel)
11091109 def_conv_kernel->create_ker ();
@@ -1147,7 +1147,7 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f
11471147
11481148 for (int ic = 0 ; ic < IC; ic++) {
11491149 const float *data_im_ptr = src + mb * src_strides[0 ] + (g * IC + ic) * src_strides[1 ] + h_in * src_strides[2 ] + w_in * src_strides[3 ];
1150- const int deformable_group_index = ic / channel_per_deformable_group;
1150+ const int deformable_group_index = (IC * g + ic) / channel_per_deformable_group;
11511151 const float *data_offset_ptr = offsets + mb * off_strides[0 ] + (deformable_group_index * 2 * KH * KW) * off_strides[1 ];
11521152 const float *modulation_offset_ptr = nullptr ;
11531153 if (modulation != nullptr ) {
@@ -1165,7 +1165,20 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f
11651165
11661166 const float h_im = h_in + map_h; // absolute pixel index with offset
11671167 const float w_im = w_in + map_w; // absolute pixel index with offset
1168- if (h_im >= 0 && w_im >= 0 && h_im < IH && w_im < IW) {
1168+ bool skip_compute;
1169+ if (with_bilinear_pad) {
1170+ skip_compute = !(static_cast <int >(w_im) > -1 &&
1171+ static_cast <int >(w_im) < IW &&
1172+ static_cast <int >(h_im) > -1 &&
1173+ static_cast <int >(h_im) < IH);
1174+ } else {
1175+ skip_compute = !(w_im >= 0 &&
1176+ w_im < IW &&
1177+ h_im >= 0 &&
1178+ h_im < IH);
1179+ }
1180+ if (!skip_compute) {
1181+ // if (h_im >= 0 && w_im >= 0 && h_im < IH && w_im < IW) {
11691182 const int cur_height = IH - h_in;
11701183 const int cur_width = IW - w_in;
11711184 int h_low = std::max (static_cast <int >(floorf (map_h)), 0 );
@@ -1192,8 +1205,8 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f
11921205 modulation_scalar = modulation_offset_ptr[modulation_index];
11931206 }
11941207
1195- const float weight = with_groups ? weights[g * wei_strides[ 0 ] + oc * wei_strides[1 ] + ic * wei_strides[2 ] + kh * wei_strides[3 ] +
1196- kw * wei_strides[4 ]]
1208+ const float weight = with_groups ? weights[(g + oc / G) * wei_strides[0 ] + ic * wei_strides[1 ] + kh * wei_strides[2 ] +
1209+ kw * wei_strides[3 ]]
11971210 : weights[oc * wei_strides[0 ] + ic * wei_strides[1 ] + kh * wei_strides[2 ] + kw * wei_strides[3 ]];
11981211 d += val * weight * modulation_scalar;
11991212 }
@@ -1205,7 +1218,7 @@ void MKLDNNDeformableConvolutionNode::executeReference(const float* src, const f
12051218 };
12061219
12071220 parallel_nd (G, MB, OC, OH, OW,
1208- [&](int g, int mb, int oc, int oh, int ow) {
1221+ [&](int g, int mb, int oc, int oh, int ow) {
12091222 dst[mb * dst_strides[0 ] + (g * OC + oc) * dst_strides[1 ] + oh * dst_strides[2 ] + ow * dst_strides[3 ]] = ker (g, mb, oc, oh, ow);
12101223 });
12111224}
0 commit comments