
Commit c409456 (parent: 948b8b9)

fix: Avoid unnecessary concat in attn_output_gate case.

Signed-off-by: Yuxian Qiu <[email protected]>

2 files changed (+4, -4 lines)

tensorrt_llm/_torch/modules/attention.py (2 additions, 3 deletions)

```diff
@@ -537,10 +537,9 @@ def forward(
             q, gate = torch.chunk(q_gate, 2, dim=-1)
             q = q.reshape(*orig_shape, -1)
             gate = gate.reshape(*orig_shape, -1)
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
+        else:
+            q, k, v = qkv, None, None
 
-        q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)
```

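In words: before this change, the attn_output_gate path split the fused projection into q, gate, k, and v, then immediately re-concatenated q, k, and v into a single qkv tensor (the removed TODO flagged exactly this redundancy), only for the tensor to be unpacked again downstream. After the change, the gated path keeps q, k, and v separate, and only the ungated path forwards the fused buffer as q with k = v = None. A minimal standalone sketch of that control flow, assuming a fused [q | gate | k | v] layout; `split_for_rope`, `q_size`, and `kv_size` are hypothetical names, not from the diff:

```python
import torch
from typing import Optional, Tuple


def split_for_rope(
    qkv: torch.Tensor,
    q_size: int,
    kv_size: int,
    attn_output_gate: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Sketch of the reworked branch: the gated path returns separate
    q/k/v; the plain path forwards the fused buffer as q with k = v = None."""
    if attn_output_gate:
        # Assumed fused layout for the gated projection: [q | gate | k | v].
        q_gate, k, v = qkv.split([2 * q_size, kv_size, kv_size], dim=-1)
        q, _gate = torch.chunk(q_gate, 2, dim=-1)  # gate is applied after attention
        # The old code re-concatenated q, k, v here only to split them again later.
        return q, k, v
    return qkv, None, None


# Usage: a [batch, seq, 2*q_size + 2*kv_size] buffer with q_size = kv_size = 64.
q, k, v = split_for_rope(torch.randn(2, 8, 256), q_size=64, kv_size=64,
                         attn_output_gate=True)
assert q.shape[-1] == 64 and k.shape[-1] == 64 and v.shape[-1] == 64
```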
tensorrt_llm/_torch/modules/qk_norm_attention.py (2 additions, 1 deletion)

```diff
@@ -252,6 +252,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         else:
             return q, k, v
 
-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
         qkv = q
+        if k is not None and v is not None:
+            qkv = torch.concat([q, k, v], dim=-1)
         return self.apply_qk_norm_rope(qkv, position_ids)
```
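The other half of the change relaxes the input contract of apply_rope in tensorrt_llm/_torch/modules/qk_norm_attention.py: instead of asserting that the caller already passed a fused qkv tensor, it now accepts either a fused tensor (with k and v set to None) or separate q/k/v, concatenating only when needed. A minimal sketch of just that fusing step; `fuse_qkv_if_needed` is a hypothetical name, and the real method goes on to call apply_qk_norm_rope:

```python
import torch
from typing import Optional


def fuse_qkv_if_needed(q: torch.Tensor,
                       k: Optional[torch.Tensor],
                       v: Optional[torch.Tensor]) -> torch.Tensor:
    """Mirror of the new branch: pass a fused qkv tensor through untouched,
    or concatenate separate q/k/v along the last dimension."""
    qkv = q
    if k is not None and v is not None:
        qkv = torch.concat([q, k, v], dim=-1)
    return qkv


# Both call shapes now work: fused (k = v = None) and separate q/k/v.
fused = fuse_qkv_if_needed(torch.randn(2, 8, 256), None, None)
split = fuse_qkv_if_needed(torch.randn(2, 8, 128),
                           torch.randn(2, 8, 64), torch.randn(2, 8, 64))
assert fused.shape == split.shape
```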
