diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 383f95a4636..74d08c30409 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -537,10 +537,9 @@ def forward(
                 t.reshape(*orig_shape, -1) for t in torch.chunk(
                     q_gate.view(*orig_shape, self.num_heads, -1), 2, dim=-1)
             ]
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
+        else:
+            q, k, v = qkv, None, None
 
-        q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)
 
diff --git a/tensorrt_llm/_torch/modules/qk_norm_attention.py b/tensorrt_llm/_torch/modules/qk_norm_attention.py
index b116394989e..e69fb33d1d2 100644
--- a/tensorrt_llm/_torch/modules/qk_norm_attention.py
+++ b/tensorrt_llm/_torch/modules/qk_norm_attention.py
@@ -249,6 +249,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         else:
             return q, k, v
 
-        assert k is None and v is None, "The input should be a concatenated qkv tensor to apply_qk_norm_rope"
         qkv = q
+        if k is not None and v is not None:
+            qkv = torch.concat([q, k, v], dim=-1)
         return self.apply_qk_norm_rope(qkv, position_ids)
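
For context, a minimal standalone sketch (not part of the patch, with hypothetical tensor shapes) of the input handling the second hunk introduces: `apply_rope` now accepts either a pre-concatenated qkv tensor (with `k` and `v` as `None`) or separate q/k/v tensors, concatenating them only when needed, which is what lets the caller in `attention.py` drop the redundant split-and-concat flagged by the removed TODO.

```python
from typing import Optional

import torch


def fuse_qkv_if_needed(q: torch.Tensor, k: Optional[torch.Tensor],
                       v: Optional[torch.Tensor]) -> torch.Tensor:
    """Sketch of the new input handling: q may already be a fused qkv
    tensor (k and v are None), or q/k/v may arrive separately and get
    concatenated along the hidden dimension here."""
    qkv = q
    if k is not None and v is not None:
        qkv = torch.concat([q, k, v], dim=-1)
    return qkv


# Both call patterns produce the same fused layout
# (shapes are illustrative only).
q = torch.randn(2, 8, 128)  # [batch, seq, q_hidden]
k = torch.randn(2, 8, 64)   # [batch, seq, kv_hidden]
v = torch.randn(2, 8, 64)

fused_from_parts = fuse_qkv_if_needed(q, k, v)
fused_from_qkv = fuse_qkv_if_needed(torch.concat([q, k, v], dim=-1), None, None)
assert torch.equal(fused_from_parts, fused_from_qkv)
```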