From d08c52c57f39e294d73436068453cfcfcb40524b Mon Sep 17 00:00:00 2001
From: Jiajia Qin
Date: Sat, 18 Oct 2025 09:51:37 +0800
Subject: [PATCH] [webgpu] Register GQA based on graph capture

---
 .../cpu/bert/group_query_attention_helper.h   | 28 ++++++-------
 .../webgpu/bert/flash_attention.cc            | 32 ++++++++++-----
 .../contrib_ops/webgpu/bert/flash_attention.h | 13 ++++---
 .../webgpu/bert/flash_attention.wgsl.template | 39 +++++++++++++------
 .../webgpu/bert/group_query_attention.cc      | 37 +++++++++++-------
 .../webgpu/bert/group_query_attention.h       |  2 +
 .../webgpu/webgpu_contrib_kernels.cc          |  8 +++-
 .../webgpu/webgpu_contrib_kernels.h           |  2 +-
 .../webgpu/webgpu_execution_provider.cc       |  2 +-
 9 files changed, 105 insertions(+), 58 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
index 46d3e7e675e85..01172bb8f3270 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -251,12 +251,14 @@ Status CheckInputs(const T* query,
                            "seqlens_k must be shape (batch_size).");
   }
 
-  // Set present sequence length from input total_seqlen tensor
   if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "total_sequence_length tensor must be of one element.");
   }
-  int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
+
+  // When graph capture is enabled, total_seqlen is on GPU and cannot be read. Skip validation.
+  const bool is_total_seqlen_on_cpu = (total_seqlen->Location().device.Type() == OrtDevice::CPU);
+  int total_sequence_length = is_total_seqlen_on_cpu ? *((*total_seqlen).template Data<int32_t>()) : 0;
   int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
 
   int rotary_dim = 0;
@@ -267,22 +269,20 @@ Status CheckInputs(const T* query,
                            "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
   }
 
+  // Skip prompt type detection when total_seqlen is on GPU (graph capture mode)
   bool is_subsequent_prompt = false;
-  if (sequence_length > 1 && sequence_length != total_sequence_length) {
-    if (batch_size != 1) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "batch_size must be 1 when sequence_length > 1 and past context is given.");
+  bool is_first_prompt = false;
+  if (is_total_seqlen_on_cpu) {
+    if (sequence_length > 1 && sequence_length != total_sequence_length) {
+      if (batch_size != 1) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "batch_size must be 1 when sequence_length > 1 and past context is given.");
+      }
+      is_subsequent_prompt = true;
     }
-    is_subsequent_prompt = true;
-  }
-  bool is_first_prompt;
-  if (is_subsequent_prompt) {
-    is_first_prompt = false;  // irrelevant for interactive decoding
-  } else {
-    // If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
     is_first_prompt = (sequence_length == total_sequence_length);
-    if (!is_first_prompt && sequence_length != 1) {
+    if (!is_subsequent_prompt && !is_first_prompt && sequence_length != 1) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                              "sequence_length shall be 1 when it is not prompt.");
     }
   }
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 1e69928f2a7ce..00f60142df159 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -31,9 +31,11 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
   const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
   const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform);
   const auto& copy_kv_shape = shader.AddIndices("copy_kv_shape");
+  if (use_seqlen_k_) {
+    shader.AddInput("seqlen_k", ShaderUsage::None);
+  }
   // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output
   if (prepare_indirect_dispatch_) {
-    shader.AddInput("seqlen_k", ShaderUsage::None);
     shader.AddOutput("indirect_buffer", ShaderUsage::None);
   }
 
@@ -43,7 +45,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
                                "  let sequence_id = output_indices[2];\n"
                                "  let num_head_id = output_indices[1];\n"
                                "  let batch = output_indices[0];\n";
-  if (prepare_indirect_dispatch_) {
+  if (use_seqlen_k_) {
     shader.MainFunctionBody() << "  let total_seq_length = u32(seqlen_k[0u]) + 1u;\n";
   } else {
     shader.MainFunctionBody() << "  let total_seq_length = uniforms.total_sequence_length;\n";
@@ -105,9 +107,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
 
   // Determine if we need to prepare indirect dispatch
   bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
+  bool use_seqlen_k = (seqlen_k != nullptr);
 
   CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH,
-                             prepare_indirect_dispatch};
+                             prepare_indirect_dispatch, use_seqlen_k};
   if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
     program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
                        {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -119,7 +122,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
                        {V, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}});
   }
 
-  if (prepare_indirect_dispatch) {
+  if (use_seqlen_k) {
     program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
   }
 
@@ -137,7 +140,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
   program.AddIndices(std::move(copy_kv_shape));
   program.SetDispatchGroupSize(static_cast<uint32_t>((copy_size + 63) / 64))
       .SetWorkgroupSize(64)
-      .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch)
+      .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch, use_seqlen_k)
       .AddUniformVariables({{static_cast<uint32_t>(copy_size)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(parameters.kv_sequence_length_)},
@@ -167,6 +170,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (has_attention_bias_) {
     shader.AddInput("attention_bias", ShaderUsage::UseUniform);
   }
+  if (use_seqlen_k_) {
+    shader.AddInput("seqlens_k", ShaderUsage::None);
+  }
   shader.AddOutput("output", ShaderUsage::UseUniform);
 
   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template",
@@ -176,7 +182,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
                              WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
                              WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_),
                              WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
-                             WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_));
+                             WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
+                             WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
 }
 
 Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
@@ -349,10 +356,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
   const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
 
+  const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();
+
   if (parameters.sequence_length_ > 1) {
     const uint32_t tile_size = 64;
     // For encode path, use the original CopyKVCache without indirect dispatch preparation
-    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
+    ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
     bool has_attention_bias = attention_bias != nullptr;
     bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
     bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -364,24 +373,27 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                   parameters.head_size_,
                                   parameters.num_heads_,
                                   parameters.is_unidirectional_,
-                                  is_nvidia};
+                                  is_nvidia,
+                                  use_seqlen_k};
     program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                        {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
                        {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
     if (has_attention_bias) {
       program.AddInputs({{attention_bias, ProgramTensorMetadataDependency::TypeAndRank}});
     }
+    if (use_seqlen_k) {
+      program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::None}});
+    }
     program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
     const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                   : parameters.scale_;
     const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
     program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
         .SetWorkgroupSize(tile_size)
-        .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia)
+        .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
         .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                               {static_cast<uint32_t>(parameters.total_sequence_length_)},
                               {static_cast<uint32_t>(present_sequence_length)},
-                              {static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)},
                               {static_cast<uint32_t>(parameters.n_reps)},
                               {alpha},
                               {num_seq_tile}});
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index f372aeed0e563..9599c10533351 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -18,8 +18,8 @@ using namespace onnxruntime::webgpu;
 class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
  public:
   CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
-                     bool prepare_indirect_dispatch = false)
-      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch) {
+                     bool prepare_indirect_dispatch = false, bool use_seqlen_k = false)
+      : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -34,6 +34,7 @@ class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
   bool has_past_;
   bool kv_BNSH_;
   bool prepare_indirect_dispatch_;
+  bool use_seqlen_k_;
 };
 
 class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
@@ -45,7 +46,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                         int qkv_head_size,
                         int qkv_num_heads,
                         bool is_unidirectional,
-                        bool is_nvidia)
+                        bool is_nvidia,
+                        bool use_seqlen_k = false)
       : Program{kernel_name},
         has_attention_bias_(has_attention_bias),
         is_qualcomm_(is_qualcomm),
@@ -53,7 +55,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
         qkv_head_size_(qkv_head_size),
         qkv_num_heads_(qkv_num_heads),
         is_unidirectional_(is_unidirectional),
-        is_nvidia_(is_nvidia) {
+        is_nvidia_(is_nvidia),
+        use_seqlen_k_(use_seqlen_k) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -61,7 +64,6 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"total_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
-                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                           {"n_reps", ProgramUniformVariableDataType::Uint32},
                                           {"alpha", ProgramUniformVariableDataType::Float32},
                                           {"num_seq_tile", ProgramUniformVariableDataType::Uint32});
@@ -74,6 +76,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   int qkv_num_heads_;
   bool is_unidirectional_;
   bool is_nvidia_;
+  bool use_seqlen_k_;
 };
 
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
index 0674702bd6030..a5922ec9512fd 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
@@ -6,10 +6,23 @@
 #param prefer_subgroupshuffle
 #param qkv_head_size
 #param qkv_num_heads
+#param use_seqlen_k
 
 const head_size : u32 = qkv_head_size;
 const num_heads : u32 = qkv_num_heads;
 
+#if use_seqlen_k
+// When graph capture is enabled, total_sequence_length is read from GPU buffer
+fn get_total_sequence_length() -> u32 {
+  return u32(seqlens_k[0]) + 1u;
+}
+#else
+// When graph capture is disabled, total_sequence_length comes from uniforms
+fn get_total_sequence_length() -> u32 {
+  return uniforms.total_sequence_length;
+}
+#endif
+
 #if is_fp16
 const min_value = q_element_t(-65504.0);
 #else
@@ -45,17 +58,17 @@ fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
   let offset = head_idx * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec;
   for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
     let slot = u32(idx / head_size_vec);
-    let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < uniforms.total_sequence_length);
+    let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length());
     k_tile[slot][idx % head_size_vec] = val;
   }
 }
 
-fn loadv(v_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
+fn loadv(v_start : u32, head_idx : u32, local_idx : u32, v_step : u32) {
   // Stored as float16[batch_size,num_heads,present_sequence_length,96]
   let offset = head_idx * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec;
-  for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
+  for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) {
     let slot = u32(idx / head_size_vec);
-    let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < uniforms.total_sequence_length);
+    let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length());
     v_tile[slot][idx % head_size_vec] = val;
   }
 }
@@ -93,12 +106,12 @@ fn writeo(o_idx_global : u32, head_idx : u32) {
 #if has_attention_bias
 fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
   // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
-  if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.total_sequence_length) {
+  if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= get_total_sequence_length()) {
     return vec4<q_element_t>(0);
   }
-  let offset_base = head_idx * uniforms.new_sequence_length * uniforms.total_sequence_length + q_idx_global * uniforms.total_sequence_length;
+  let offset_base = head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length();
   let offset = offset_base + k_idx_global;
-  let offset_max = offset_base + uniforms.total_sequence_length;
+  let offset_max = offset_base + get_total_sequence_length();
   let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
   let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]);
   let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]);
@@ -141,16 +154,18 @@ $MAIN {
   var previous_max : q_element_t = min_value;
   var previous_denom : q_element_t = 0;
 
+  let total_sequence_length = get_total_sequence_length();
 #if is_unidirectional
   // If attention is unidirectional, set the loop bound to enforce causal masking.
-  let max_causal_len_for_workgroup = uniforms.past_sequence_length +
+  let past_sequence_length = total_sequence_length - uniforms.new_sequence_length;
+  let max_causal_len_for_workgroup = past_sequence_length +
                                      (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x;
-  let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup);
-  let seq_causal_length = uniforms.past_sequence_length + q_idx_global + 1;
+  let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup);
+  let seq_causal_length = past_sequence_length + q_idx_global + 1;
 #else
-  let loop_bound = uniforms.total_sequence_length;
-  let seq_causal_length = uniforms.total_sequence_length;
+  let loop_bound = total_sequence_length;
+  let seq_causal_length = total_sequence_length;
 #endif
 
   for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) {
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index 49cc0209785c5..27215ce144816 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -19,18 +19,6 @@ namespace onnxruntime {
 namespace contrib {
 namespace webgpu {
 
-ONNX_OPERATOR_KERNEL_EX(
-    GroupQueryAttention,
-    kMSDomain,
-    1,
-    kWebGpuExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T", WebGpuSupportedFloatTypes())
-        .MayInplace(3, 1)
-        .MayInplace(4, 2)
-        .InputMemoryType(OrtMemTypeCPUInput, 6),
-    GroupQueryAttention);
-
 Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
   const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
   const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
@@ -287,7 +275,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
       !use_sliding_window &&
       CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
-                               present_value, parameters, context);
+                               present_value, parameters, context, seqlen_k);
   }
 
   TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
@@ -318,6 +306,29 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
                             present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
 }
 
+KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture) {
+  KernelDefBuilder builder;
+  builder.SetName("GroupQueryAttention")
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .Provider(kWebGpuExecutionProvider)
+      .TypeConstraint("T", WebGpuSupportedFloatTypes())
+      .MayInplace(3, 1)
+      .MayInplace(4, 2);
+
+  // Only set InputMemoryType to CPU when graph capture is disabled
+  if (!enable_graph_capture) {
+    builder.InputMemoryType(OrtMemTypeCPUInput, 6);
+  }
+
+  return KernelCreateInfo(
+      builder.Build(),
+      [](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
+        out = std::make_unique<GroupQueryAttention>(info);
+        return Status::OK();
+      });
+}
+
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
index 1fb1e1ffc91fd..51280f18da4b9 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
@@ -71,6 +71,8 @@ class GroupQueryAttention final : public WebGpuKernel {
   Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
 };
 
+KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture);
+
 }  // namespace webgpu
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
index 25cc13b3ea1df..b7d70eb99f23e 100644
--- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+#include "contrib_ops/webgpu/bert/group_query_attention.h"
 
 #include "core/framework/op_kernel.h"
 
@@ -34,7 +35,7 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
   return info;
 }
 
-Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
+Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable_graph_capture) {
   static const BuildKernelCreateInfoFn function_table[] = {
      BuildKernelCreateInfo<void>,  // default entry to avoid the list become empty after ops-reducing
      // BuildKernelCreateInfo,
@@ -44,7 +45,6 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
-     BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
@@ -61,6 +61,10 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
       ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
     }
   }
+
+  // Register GroupQueryAttention with conditional InputMemoryType based on graph capture
+  ORT_RETURN_IF_ERROR(kernel_registry.Register(CreateGroupQueryAttentionKernelInfo(enable_graph_capture)));
+
   return Status::OK();
 }
 
diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h
index d73859de78239..a4fcfb8390798 100644
--- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h
+++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h
@@ -13,7 +13,7 @@ namespace webgpu {
 template <typename T>
 KernelCreateInfo BuildKernelCreateInfo();
 
-Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry);
+Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable_graph_capture = false);
 
 }  // namespace webgpu
 }  // namespace contrib
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index 135782ad577c4..878968529e6a7 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -780,7 +780,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = fals
   ORT_THROW_IF_ERROR(kernel_registry->Register(CreateCastKernelInfo<23>(enable_graph_capture)));
 
 #ifndef DISABLE_CONTRIB_OPS
-  Status status = ::onnxruntime::contrib::webgpu::RegisterWebGpuContribKernels(*kernel_registry);
+  Status status = ::onnxruntime::contrib::webgpu::RegisterWebGpuContribKernels(*kernel_registry, enable_graph_capture);
   ORT_ENFORCE(status.IsOK(), "Failed to register WebGPU contrib kernels: " + status.ErrorMessage());
 #endif