diff --git a/backends/npu/kernels/index_sample_kernel.cc b/backends/npu/kernels/index_sample_kernel.cc index a06c0c508..c7ca2bf2f 100644 --- a/backends/npu/kernels/index_sample_kernel.cc +++ b/backends/npu/kernels/index_sample_kernel.cc @@ -101,6 +101,8 @@ void IndexSampleGather(const Context& dev_ctx, dev_ctx.template Alloc<int64_t>(&tmp_out_t); } else if (tmp_input.dtype() == phi::DataType::FLOAT16) { dev_ctx.template Alloc<phi::dtype::float16>(&tmp_out_t); + } else if (tmp_input.dtype() == phi::DataType::BFLOAT16) { + dev_ctx.template Alloc<phi::dtype::bfloat16>(&tmp_out_t); } NpuOpRunner gather_runner; @@ -138,6 +140,9 @@ void IndexSampleGather(const Context& dev_ctx, } else if (dtype == phi::DataType::FLOAT16) { custom_kernel::ConcatKernel<phi::dtype::float16, Context>( dev_ctx, concat_input, axis, out); + } else if (dtype == phi::DataType::BFLOAT16) { + custom_kernel::ConcatKernel<phi::dtype::bfloat16, Context>( + dev_ctx, concat_input, axis, out); } out->Resize(out_dim); @@ -169,6 +174,9 @@ void IndexSampleGather(const Context& dev_ctx, } else if (dtype == phi::DataType::FLOAT16) { custom_kernel::GatherNdKernel<phi::dtype::float16, Context>( dev_ctx, *input, gather_index, out); + } else if (dtype == phi::DataType::BFLOAT16) { + custom_kernel::GatherNdKernel<phi::dtype::bfloat16, Context>( + dev_ctx, *input, gather_index, out); } } } @@ -247,7 +255,8 @@ PD_REGISTER_PLUGIN_KERNEL(index_sample, phi::dtype::float16, float, int32_t, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_PLUGIN_KERNEL(index_sample_grad, npu, @@ -256,4 +265,5 @@ PD_REGISTER_PLUGIN_KERNEL(index_sample_grad, phi::dtype::float16, float, int32_t, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {}