diff --git a/backends/npu/kernels/index_sample_kernel.cc b/backends/npu/kernels/index_sample_kernel.cc
index a06c0c508..c7ca2bf2f 100644
--- a/backends/npu/kernels/index_sample_kernel.cc
+++ b/backends/npu/kernels/index_sample_kernel.cc
@@ -101,6 +101,8 @@ void IndexSampleGather(const Context& dev_ctx,
         dev_ctx.template Alloc<int64_t>(&tmp_out_t);
       } else if (tmp_input.dtype() == phi::DataType::FLOAT16) {
         dev_ctx.template Alloc<phi::dtype::float16>(&tmp_out_t);
+      } else if (tmp_input.dtype() == phi::DataType::BFLOAT16) {
+        dev_ctx.template Alloc<phi::dtype::bfloat16>(&tmp_out_t);
       }
 
       NpuOpRunner gather_runner;
@@ -138,6 +140,9 @@ void IndexSampleGather(const Context& dev_ctx,
     } else if (dtype == phi::DataType::FLOAT16) {
       custom_kernel::ConcatKernel<phi::dtype::float16, Context>(
           dev_ctx, concat_input, axis, out);
+    } else if (dtype == phi::DataType::BFLOAT16) {
+      custom_kernel::ConcatKernel<phi::dtype::bfloat16, Context>(
+          dev_ctx, concat_input, axis, out);
     }
     out->Resize(out_dim);
 
@@ -169,6 +174,9 @@ void IndexSampleGather(const Context& dev_ctx,
     } else if (dtype == phi::DataType::FLOAT16) {
       custom_kernel::GatherNdKernel<phi::dtype::float16, Context>(
           dev_ctx, *input, gather_index, out);
+    } else if (dtype == phi::DataType::BFLOAT16) {
+      custom_kernel::GatherNdKernel<phi::dtype::bfloat16, Context>(
+          dev_ctx, *input, gather_index, out);
     }
   }
 }
@@ -247,7 +255,8 @@ PD_REGISTER_PLUGIN_KERNEL(index_sample,
                           phi::dtype::float16,
                           float,
                           int32_t,
-                          int64_t) {}
+                          int64_t,
+                          phi::dtype::bfloat16) {}
 
 PD_REGISTER_PLUGIN_KERNEL(index_sample_grad,
                           npu,
@@ -256,4 +265,5 @@ PD_REGISTER_PLUGIN_KERNEL(index_sample_grad,
                           phi::dtype::float16,
                           float,
                           int32_t,
-                          int64_t) {}
+                          int64_t,
+                          phi::dtype::bfloat16) {}