From 72e22e86c358b6e831b55fd4d37c63b305c71d9b Mon Sep 17 00:00:00 2001 From: chenxiao Date: Sat, 21 Jun 2025 11:53:58 +0800 Subject: [PATCH] Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) --- src/diffusers/models/transformers/transformer_cosmos.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 6c312b7a5a3f..2ffb4ae41b33 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ b/src/diffusers/models/transformers/transformer_cosmos.py @@ -186,9 +186,9 @@ def __call__( key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2) # 4. Prepare for GQA - query_idx = torch.tensor(query.size(3), device=query.device) - key_idx = torch.tensor(key.size(3), device=key.device) - value_idx = torch.tensor(value.size(3), device=value.device) + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) key = key.repeat_interleave(query_idx // key_idx, dim=3) value = value.repeat_interleave(query_idx // value_idx, dim=3)