28 changes: 14 additions & 14 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -251,12 +251,14 @@ Status CheckInputs(const T* query,
"seqlens_k must be shape (batch_size).");
}

// Set present sequence length from input total_seqlen tensor
if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"total_sequence_length tensor must be of one element.");
}
int total_sequence_length = *((*total_seqlen).template Data<int32_t>());

// When graph capture is enabled, total_seqlen is on GPU and cannot be read. Skip validation.
const bool is_total_seqlen_on_cpu = (total_seqlen->Location().device.Type() == OrtDevice::CPU);
int total_sequence_length = is_total_seqlen_on_cpu ? *((*total_seqlen).template Data<int32_t>()) : 0;
int present_sequence_length = std::max(total_sequence_length, past_sequence_length);

int rotary_dim = 0;
@@ -267,22 +269,20 @@ Status CheckInputs(const T* query,
"Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
}

// Skip prompt type detection when total_seqlen is on GPU (graph capture mode)
bool is_subsequent_prompt = false;
if (sequence_length > 1 && sequence_length != total_sequence_length) {
if (batch_size != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"batch_size must be 1 when sequence_length > 1 and past context is given.");
bool is_first_prompt = false;
if (is_total_seqlen_on_cpu) {
if (sequence_length > 1 && sequence_length != total_sequence_length) {
if (batch_size != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"batch_size must be 1 when sequence_length > 1 and past context is given.");
}
is_subsequent_prompt = true;
}
is_subsequent_prompt = true;
}

bool is_first_prompt;
if (is_subsequent_prompt) {
is_first_prompt = false; // irrelevant for interactive decoding
} else {
// If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
is_first_prompt = (sequence_length == total_sequence_length);
if (!is_first_prompt && sequence_length != 1) {
if (!is_subsequent_prompt && !is_first_prompt && sequence_length != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sequence_length shall be 1 when it is not prompt.");
}
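For context on the helper change above: total_seqlen is now dereferenced only when it actually lives in host memory; under graph capture it stays on the GPU and the host-side checks are skipped. A minimal standalone sketch of that pattern (the struct and function below are illustrative, not code from this PR):

```cpp
// Illustrative sketch only -- mirrors the CheckInputs logic above, assuming
// total_seqlen may be GPU-resident when graph capture is enabled.
#include <algorithm>

struct TotalSeqLen {
  bool on_cpu;  // true when the tensor is host-visible and safe to read
  int value;    // only meaningful when on_cpu is true
};

inline int PresentSequenceLength(const TotalSeqLen& total, int past_sequence_length) {
  // When the tensor is GPU-resident, treat the total as unknown (0) and fall
  // back to the past length; the GPU-side kernels read seqlens_k themselves.
  const int total_sequence_length = total.on_cpu ? total.value : 0;
  return std::max(total_sequence_length, past_sequence_length);
}
```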
32 changes: 22 additions & 10 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -31,9 +31,11 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform);
const auto& copy_kv_shape = shader.AddIndices("copy_kv_shape");
if (use_seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::None);
}
// If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output
if (prepare_indirect_dispatch_) {
shader.AddInput("seqlen_k", ShaderUsage::None);
shader.AddOutput("indirect_buffer", ShaderUsage::None);
}

@@ -43,7 +45,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
" let sequence_id = output_indices[2];\n"
" let num_head_id = output_indices[1];\n"
" let batch = output_indices[0];\n";
if (prepare_indirect_dispatch_) {
if (use_seqlen_k_) {
shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n";
} else {
shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n";
@@ -105,9 +107,10 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt

// Determine if we need to prepare indirect dispatch
bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
bool use_seqlen_k = (seqlen_k != nullptr);

CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH,
prepare_indirect_dispatch};
prepare_indirect_dispatch, use_seqlen_k};
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -119,7 +122,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
{V, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}});
}

if (prepare_indirect_dispatch) {
if (use_seqlen_k) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
}

@@ -137,7 +140,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
program.AddIndices(std::move(copy_kv_shape));
program.SetDispatchGroupSize(static_cast<uint32_t>((copy_size + 63) / 64))
.SetWorkgroupSize(64)
.CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch)
.CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch, use_seqlen_k)
.AddUniformVariables({{static_cast<uint32_t>(copy_size)},
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
@@ -167,6 +170,9 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
if (use_seqlen_k_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
shader.AddOutput("output", ShaderUsage::UseUniform);

return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template",
@@ -176,7 +182,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
WGSL_TEMPLATE_PARAMETER(prefer_subgroupshuffle, !is_nvidia_),
WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_),
WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_));
WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_),
WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_));
}

Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
@@ -349,10 +356,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);

const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();

if (parameters.sequence_length_ > 1) {
const uint32_t tile_size = 64;
// For encode path, use the original CopyKVCache without indirect dispatch preparation
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
bool has_attention_bias = attention_bias != nullptr;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -364,24 +373,27 @@
parameters.head_size_,
parameters.num_heads_,
parameters.is_unidirectional_,
is_nvidia};
is_nvidia,
use_seqlen_k};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
{present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
{present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
if (has_attention_bias) {
program.AddInputs({{attention_bias, ProgramTensorMetadataDependency::TypeAndRank}});
}
if (use_seqlen_k) {
program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::None}});
}
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
: parameters.scale_;
const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile)
.SetWorkgroupSize(tile_size)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, use_seqlen_k)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(present_sequence_length)},
{static_cast<uint32_t>(parameters.total_sequence_length_ - parameters.kv_sequence_length_)},
{static_cast<uint32_t>(parameters.n_reps)},
{alpha},
{num_seq_tile}});
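Summarizing the ApplyFlashAttention changes above as a hedged sketch (the types are stand-ins, only the gating logic reflects the diff): seqlen_k is wired into CopyKVCache and FlashAttentionProgram only when graph capture is enabled, so the existing uniform-based path is untouched otherwise.

```cpp
// Sketch with placeholder types -- not the real ORT interfaces.
struct Tensor {};
struct ComputeContext {
  bool IsGraphCaptureEnabled() const { return graph_capture_; }
  bool graph_capture_ = false;
};

// Decide whether the shader should read total_sequence_length from the
// GPU-resident seqlen_k buffer instead of a CPU-provided uniform.
inline const Tensor* SeqlenKIfGraphCapture(const Tensor* seqlen_k, const ComputeContext& context) {
  const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();
  return use_seqlen_k ? seqlen_k : nullptr;
}
```

Because use_seqlen_k selects a different shader body (the #if use_seqlen_k branch in the template), it also has to be part of CacheHint; otherwise a pipeline compiled for the uniform path could be reused for the graph-capture path.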
13 changes: 8 additions & 5 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -18,8 +18,8 @@ using namespace onnxruntime::webgpu;
class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
bool prepare_indirect_dispatch = false)
: Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch) {
bool prepare_indirect_dispatch = false, bool use_seqlen_k = false)
: Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch), use_seqlen_k_(use_seqlen_k) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -34,6 +34,7 @@ class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
bool has_past_;
bool kv_BNSH_;
bool prepare_indirect_dispatch_;
bool use_seqlen_k_;
};

class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
@@ -45,23 +46,24 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
int qkv_head_size,
int qkv_num_heads,
bool is_unidirectional,
bool is_nvidia)
bool is_nvidia,
bool use_seqlen_k = false)
: Program{kernel_name},
has_attention_bias_(has_attention_bias),
is_qualcomm_(is_qualcomm),
is_fp16_(is_fp16),
qkv_head_size_(qkv_head_size),
qkv_num_heads_(qkv_num_heads),
is_unidirectional_(is_unidirectional),
is_nvidia_(is_nvidia) {
is_nvidia_(is_nvidia),
use_seqlen_k_(use_seqlen_k) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"new_sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32},
{"num_seq_tile", ProgramUniformVariableDataType::Uint32});
@@ -74,6 +76,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
int qkv_num_heads_;
bool is_unidirectional_;
bool is_nvidia_;
bool use_seqlen_k_;
};

class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
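Header note: the new use_seqlen_k parameter defaults to false in both program classes, so existing call sites keep compiling and only the graph-capture-aware path opts in. A hypothetical construction, assuming the surrounding ORT headers (argument values are illustrative):

```cpp
// Hypothetical call sites, not taken from this PR.
CopyKVCacheProgram copy_default{"CopyKVCache", /*has_past*/ true, /*kv_BNSH*/ false};
CopyKVCacheProgram copy_capture{"CopyKVCache", /*has_past*/ true, /*kv_BNSH*/ false,
                                /*prepare_indirect_dispatch*/ true, /*use_seqlen_k*/ true};
```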
39 changes: 27 additions & 12 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
@@ -6,10 +6,23 @@
#param prefer_subgroupshuffle
#param qkv_head_size
#param qkv_num_heads
#param use_seqlen_k

const head_size : u32 = qkv_head_size;
const num_heads : u32 = qkv_num_heads;

#if use_seqlen_k
// When graph capture is enabled, total_sequence_length is read from GPU buffer
fn get_total_sequence_length() -> u32 {
return u32(seqlens_k[0]) + 1u;
}
#else
// When graph capture is disabled, total_sequence_length comes from uniforms
fn get_total_sequence_length() -> u32 {
return uniforms.total_sequence_length;
}
#endif

#if is_fp16
const min_value = q_element_t(-65504.0);
#else
@@ -45,17 +58,17 @@ fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
let offset = head_idx * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec;
for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
let slot = u32(idx / head_size_vec);
let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < uniforms.total_sequence_length);
let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length());
k_tile[slot][idx % head_size_vec] = val;
}
}

fn loadv(v_start : u32, head_idx : u32, local_idx : u32, k_step : u32) {
fn loadv(v_start : u32, head_idx : u32, local_idx : u32, v_step : u32) {
// Stored as float16[batch_size,num_heads,present_sequence_length,96]
let offset = head_idx * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec;
for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) {
let slot = u32(idx / head_size_vec);
let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < uniforms.total_sequence_length);
let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length());
v_tile[slot][idx % head_size_vec] = val;
}
}
@@ -93,12 +106,12 @@ fn writeo(o_idx_global : u32, head_idx : u32) {
#if has_attention_bias
fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
// Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.total_sequence_length) {
if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= get_total_sequence_length()) {
return vec4<q_element_t>(0);
}
let offset_base = head_idx * uniforms.new_sequence_length * uniforms.total_sequence_length + q_idx_global * uniforms.total_sequence_length;
let offset_base = head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length();
let offset = offset_base + k_idx_global;
let offset_max = offset_base + uniforms.total_sequence_length;
let offset_max = offset_base + get_total_sequence_length();
let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]);
let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]);
@@ -141,16 +154,18 @@ $MAIN {

var previous_max : q_element_t = min_value;
var previous_denom : q_element_t = 0;
let total_sequence_length = get_total_sequence_length();

#if is_unidirectional
// If attention is unidirectional, set the loop bound to enforce causal masking.
let max_causal_len_for_workgroup = uniforms.past_sequence_length +
let past_sequence_length = total_sequence_length - uniforms.new_sequence_length;
let max_causal_len_for_workgroup = past_sequence_length +
(workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x;
let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup);
let seq_causal_length = uniforms.past_sequence_length + q_idx_global + 1;
let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup);
let seq_causal_length = past_sequence_length + q_idx_global + 1;
#else
let loop_bound = uniforms.total_sequence_length;
let seq_causal_length = uniforms.total_sequence_length;
let loop_bound = total_sequence_length;
let seq_causal_length = total_sequence_length;
#endif

for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) {
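The template edits above boil down to: total_sequence_length now comes from get_total_sequence_length() (the seqlens_k buffer when use_seqlen_k is set, the uniform otherwise), and the causal-mask arithmetic derives past_sequence_length in-shader from that total instead of reading uniforms.past_sequence_length. A host-side C++ mirror of the per-workgroup bound computation, assuming the same variable meanings as the WGSL:

```cpp
// Mirrors the causal bounds computed in $MAIN above (illustrative C++; the
// shader performs this per workgroup on the GPU).
#include <algorithm>
#include <cstdint>

struct CausalBounds {
  uint32_t loop_bound;         // how far the k loop runs for this workgroup
  uint32_t seq_causal_length;  // last key position visible to this query row
};

inline CausalBounds ComputeCausalBounds(uint32_t total_sequence_length,
                                        uint32_t new_sequence_length,
                                        uint32_t workgroup_idx,
                                        uint32_t num_seq_tile,
                                        uint32_t workgroup_size_x,
                                        uint32_t q_idx_global) {
  const uint32_t past_sequence_length = total_sequence_length - new_sequence_length;
  const uint32_t max_causal_len_for_workgroup =
      past_sequence_length + (workgroup_idx % num_seq_tile + 1) * workgroup_size_x;
  return {std::min(total_sequence_length, max_causal_len_for_workgroup),
          past_sequence_length + q_idx_global + 1};
}
```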
37 changes: 24 additions & 13 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -19,18 +19,6 @@
namespace contrib {
namespace webgpu {

ONNX_OPERATOR_KERNEL_EX(
GroupQueryAttention,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.MayInplace(3, 1)
.MayInplace(4, 2)
.InputMemoryType(OrtMemTypeCPUInput, 6),
GroupQueryAttention);

Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
@@ -287,7 +275,7 @@
!use_sliding_window &&
CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
present_value, parameters, context, seqlen_k);
}

TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
@@ -318,6 +306,29 @@
present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
}

KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture) {
KernelDefBuilder builder;
builder.SetName("GroupQueryAttention")
.SetDomain(kMSDomain)
.SinceVersion(1)
.Provider(kWebGpuExecutionProvider)
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.MayInplace(3, 1)
.MayInplace(4, 2);

// Only set InputMemoryType to CPU when graph capture is disabled
if (!enable_graph_capture) {
builder.InputMemoryType(OrtMemTypeCPUInput, 6);
}

return KernelCreateInfo(
builder.Build(),
[](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<GroupQueryAttention>(info);

[GitHub Actions / Optional Lint C++] cpplint via reviewdog: Add #include <memory> for make_unique<> [build/include_what_you_use] [4] (line 327 of onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc)
return Status::OK();
});
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
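The registration side is outside this diff, but the intent of the new factory is clear from the builder: input 6 (total_sequence_length) is pinned to CPU memory only when graph capture is disabled. A hypothetical caller, purely to show how the flag would flow in (RegisterGqaKernel and the registry call are placeholders, not the EP's actual registration code):

```cpp
// Hypothetical wiring -- the real WebGPU EP registration lives elsewhere and
// this assumes the usual onnxruntime headers are available.
Status RegisterGqaKernel(KernelRegistry& registry, bool enable_graph_capture) {
  // Under graph capture, total_sequence_length (input 6) stays on the GPU, so
  // the kernel definition must not force it into CPU memory.
  return registry.Register(CreateGroupQueryAttentionKernelInfo(enable_graph_capture));
}
```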
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
@@ -71,6 +71,8 @@ class GroupQueryAttention final : public WebGpuKernel {
Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
};

KernelCreateInfo CreateGroupQueryAttentionKernelInfo(bool enable_graph_capture);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime