Skip to content

Conversation

@qjia7
Copy link
Contributor

@qjia7 qjia7 commented Oct 22, 2025

This pull request enables conditionally register GQA with total_sequence_length on gpu or not. It resolves the issue that a MemcpyToHost is generated when graph capture is enabled (refer to #25868). This is the last functionality part to support graph capture in webgpu ep in ORT.

The main changes ensure that when graph capture is enabled, sequence length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models.

In this PR, we still get total sequence length from seqlen_k tensor not total_seqlen_tensor tensor to keep consistent with other parts. In the next PR, we can refactor all places to directly use total_seqlen_tensor instead of seqlen_k when graph capture enabled.

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Oct 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants