Skip to content

Commit 9742542

Browse files
authored
Un-hardcode chroma patch_size (Comfy-Org#8840)
1 parent c5de495 commit 9742542

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

comfy/ldm/chroma/model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,18 +254,17 @@ def block_wrap(args):
254254

255255
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
256256
bs, c, h, w = x.shape
257-
patch_size = 2
258-
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
257+
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
259258

260-
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
259+
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size)
261260

262-
h_len = ((h + (patch_size // 2)) // patch_size)
263-
w_len = ((w + (patch_size // 2)) // patch_size)
261+
h_len = ((h + (self.patch_size // 2)) // self.patch_size)
262+
w_len = ((w + (self.patch_size // 2)) // self.patch_size)
264263
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
265264
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
266265
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
267266
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
268267

269268
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
270269
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
271-
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
270+
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]

0 commit comments

Comments
 (0)