
Commit 25270d1

fix: Avoid unnecessary concat in attn_output_gate case.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent: e9e4632

2 files changed: +4 additions, −4 deletions

tensorrt_llm/_torch/modules/attention.py

Lines changed: 2 additions & 3 deletions
@@ -537,10 +537,9 @@ def forward(
                 t.reshape(*orig_shape, -1) for t in torch.chunk(
                     q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
             ]
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
+        else:
+            q, k, v = qkv, None, None

-        q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)

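Note: with this change, the attn_output_gate branch keeps q, k, v split instead of re-concatenating them, while the non-gated branch simply aliases the fused qkv tensor as q (with k and v set to None). Any concat that is still required is deferred to apply_rope, as the next file shows.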
tensorrt_llm/_torch/modules/qk_norm_attention.py

Lines changed: 2 additions & 1 deletion
@@ -249,6 +249,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         else:
             return q, k, v

-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
         qkv = q
+        if k is not None and v is not None:
+            qkv = torch.concat([q, k, v], dim=-1)
         return self.apply_qk_norm_rope(qkv, position_ids)
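To make the new contract concrete, here is a minimal, self-contained sketch of the fuse-on-demand pattern that apply_rope now follows. The helper name fuse_qkv_if_split and the tensor shapes are illustrative assumptions, not part of the TensorRT-LLM API:

    # Minimal sketch of the dispatch introduced by this commit (assumes torch).
    import torch
    from typing import Optional

    def fuse_qkv_if_split(
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Fused path: the caller already holds a concatenated [q|k|v] tensor,
        # so no concat is needed (the non-gated case after this commit).
        if k is None and v is None:
            return q
        # Split path (e.g. attn_output_gate): concatenate only here, where a
        # fused tensor is actually required.
        return torch.concat([q, k, v], dim=-1)

    if __name__ == "__main__":
        tokens, q_dim, kv_dim = 4, 64, 16
        fused = torch.randn(tokens, q_dim + 2 * kv_dim)
        # Fused input passes through untouched: no extra copy.
        assert fuse_qkv_if_split(fused, None, None) is fused
        # Split input is concatenated on demand.
        q = torch.randn(tokens, q_dim)
        k = torch.randn(tokens, kv_dim)
        v = torch.randn(tokens, kv_dim)
        assert fuse_qkv_if_split(q, k, v).shape == (tokens, q_dim + 2 * kv_dim)

The design point, as the diff suggests, is that the concat now happens only at the one consumer that genuinely needs a fused tensor (apply_qk_norm_rope), instead of being performed unconditionally in forward().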
