From a0140edf52729c67fd2694640a86ee793bf6c019 Mon Sep 17 00:00:00 2001 From: warrentdrew Date: Mon, 19 May 2025 20:01:22 +0800 Subject: [PATCH 1/2] conv3d support bf16 infer --- backends/npu/kernels/conv_kernel.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/backends/npu/kernels/conv_kernel.cc b/backends/npu/kernels/conv_kernel.cc index c3ab58668..83f45a9b5 100644 --- a/backends/npu/kernels/conv_kernel.cc +++ b/backends/npu/kernels/conv_kernel.cc @@ -537,6 +537,9 @@ void Conv3dKernel(const Context& dev_ctx, dilations_vec[4] = dilations[2]; auto stream = dev_ctx.stream(); + if (!FLAGS_npu_jit_compile) { + aclSetCompileopt(ACL_OP_JIT_COMPILE, "enable"); + } const auto& runner = NpuOpRunner("Conv3D", {input_tensor, filter_tensor}, {output_tensor}, @@ -546,6 +549,9 @@ void Conv3dKernel(const Context& dev_ctx, {"groups", groups}, {"data_format", data_format}}); runner.Run(stream); + if (!FLAGS_npu_jit_compile) { + aclSetCompileopt(ACL_OP_JIT_COMPILE, "disable"); + } } template @@ -709,7 +715,8 @@ PD_REGISTER_PLUGIN_KERNEL(conv3d, ALL_LAYOUT, custom_kernel::Conv3dKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_PLUGIN_KERNEL(conv3d_grad, npu, From aed29cccb3d20c26678cf5d8cada68cf29a0e05f Mon Sep 17 00:00:00 2001 From: warrentdrew Date: Wed, 21 May 2025 15:26:56 +0800 Subject: [PATCH 2/2] fix format --- backends/npu/kernels/conv_kernel.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/npu/kernels/conv_kernel.cc b/backends/npu/kernels/conv_kernel.cc index 83f45a9b5..1cce64e32 100644 --- a/backends/npu/kernels/conv_kernel.cc +++ b/backends/npu/kernels/conv_kernel.cc @@ -538,7 +538,7 @@ void Conv3dKernel(const Context& dev_ctx, auto stream = dev_ctx.stream(); if (!FLAGS_npu_jit_compile) { - aclSetCompileopt(ACL_OP_JIT_COMPILE, "enable"); + aclSetCompileopt(ACL_OP_JIT_COMPILE, "enable"); } const auto& runner = NpuOpRunner("Conv3D", {input_tensor, filter_tensor}, @@ -550,7 +550,7 @@ void Conv3dKernel(const Context& dev_ctx, {"data_format", data_format}}); runner.Run(stream); if (!FLAGS_npu_jit_compile) { - aclSetCompileopt(ACL_OP_JIT_COMPILE, "disable"); + aclSetCompileopt(ACL_OP_JIT_COMPILE, "disable"); } }