From 25270d117dc4df029457938e8f1402c0b75e765d Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Wed, 1 Oct 2025 01:51:41 +0000
Subject: [PATCH] fix: Avoid unnecessary concat in attn_output_gate case.

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/attention.py         | 5 ++---
 tensorrt_llm/_torch/modules/qk_norm_attention.py | 3 ++-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 383f95a4636..74d08c30409 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -537,10 +537,9 @@ def forward(
                 t.reshape(*orig_shape, -1) for t in torch.chunk(
                     q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
             ]
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
+        else:
+            q, k, v = qkv, None, None
 
-        q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)
 
diff --git a/tensorrt_llm/_torch/modules/qk_norm_attention.py b/tensorrt_llm/_torch/modules/qk_norm_attention.py
index b116394989e..e69fb33d1d2 100644
--- a/tensorrt_llm/_torch/modules/qk_norm_attention.py
+++ b/tensorrt_llm/_torch/modules/qk_norm_attention.py
@@ -249,6 +249,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         else:
             return q, k, v
 
-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
         qkv = q
+        if k is not None and v is not None:
+            qkv = torch.concat([q, k, v], dim=-1)
         return self.apply_qk_norm_rope(qkv, position_ids)
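
For reference, a minimal standalone sketch of the control flow after this change. The helper names (forward_sketch, apply_rope_sketch) and the toy sizes are assumptions for illustration only; the real modules also carry RoPE parameters, per-head reshapes, and fused-attention paths that are omitted here.

    import torch
    from typing import Optional

    def apply_rope_sketch(q: torch.Tensor, k: Optional[torch.Tensor],
                          v: Optional[torch.Tensor]) -> torch.Tensor:
        # After the patch, apply_rope accepts either a fused qkv tensor
        # (k and v are None) or three separate tensors, and concatenates
        # only in the separate-tensor case.
        qkv = q
        if k is not None and v is not None:
            qkv = torch.concat([q, k, v], dim=-1)
        return qkv  # the real module passes this on to apply_qk_norm_rope

    def forward_sketch(qkv: torch.Tensor, attn_output_gate: bool,
                       q_size: int, kv_size: int) -> torch.Tensor:
        if attn_output_gate:
            # The gated path already splits q, k, v apart; keep them separate
            # instead of concatenating them only to split them again later.
            q_gate, k, v = qkv.split([q_size * 2, kv_size, kv_size], dim=-1)
            q, _gate = torch.chunk(q_gate, 2, dim=-1)  # per-head reshape omitted
        else:
            q, k, v = qkv, None, None
        return apply_rope_sketch(q, k, v)

    # Toy example: 4 tokens, q_size=8, kv_size=4 -> fused width 8 * 2 + 4 + 4 = 24.
    tokens = torch.randn(4, 24)
    out = forward_sketch(tokens, attn_output_gate=True, q_size=8, kv_size=4)
    assert out.shape == (4, 8 + 4 + 4)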