Skip to content

Commit 8d38ea3

Browse files
Fix bf16 precision issue with qwen image embeddings. (Comfy-Org#9441)
1 parent 5a8f502 commit 8d38ea3

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
347347
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
348348
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
349349

350-
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
350+
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
351351
img_ids[:, :, 0] = img_ids[:, :, 1] + index
352352
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
353353
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
@@ -397,9 +397,10 @@ def forward(
397397
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
398398

399399
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
400-
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
400+
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
401401
ids = torch.cat((txt_ids, img_ids), dim=1)
402402
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
403+
del ids, txt_ids, img_ids
403404

404405
hidden_states = self.img_in(hidden_states)
405406
encoder_hidden_states = self.txt_norm(encoder_hidden_states)

0 commit comments

Comments (0)