@@ -8,7 +8,7 @@
 from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
-
+import comfy.ldm.common_dit
 
 class GELU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -364,8 +364,9 @@ def forward(
 
         image_rotary_emb = self.pos_embeds(x, context)
 
-        orig_shape = x.shape
-        hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
         hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
         hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
 
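The patchify step in this hunk (view, permute, reshape) carves the latent into patch_size x patch_size blocks, which silently requires the height and width to be even multiples of the patch size; padding first lifts that restriction, and the padded rows and columns are cropped away again at the end of forward. Below is a minimal sketch of the padding idea, using a hypothetical pad_to_multiple helper as a stand-in (the real comfy.ldm.common_dit.pad_to_patch_size may differ in signature and padding mode):

import torch
import torch.nn.functional as F

def pad_to_multiple(x, patch_size):
    # Hypothetical stand-in for comfy.ldm.common_dit.pad_to_patch_size:
    # right/bottom-pads each trailing dim of x up to the next multiple of
    # the matching entry in patch_size, e.g. (1, 2, 2) for (C, H, W).
    pad = []
    for i, p in enumerate(reversed(patch_size)):
        dim = x.shape[-1 - i]
        pad.extend([0, (p - dim % p) % p])  # (left, right) per dim, last dim first
    return F.pad(x, pad)

x = torch.randn(1, 16, 33, 47)   # odd H and W: x.view(..., H // 2, 2, ...) would fail
h = pad_to_multiple(x, (1, 2, 2))
print(h.shape)                   # torch.Size([1, 16, 34, 48]) -- divisible by the patch size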
@@ -396,4 +397,4 @@ def forward(
 
         hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
         hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
-        return hidden_states.reshape(orig_shape)
+        return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
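Because the model now runs on the padded tensor, the un-patchified output has the padded height and width; the new slice crops it back to the caller's original x.shape[-2] and x.shape[-1]. A round-trip sketch of that pad-then-crop contract, reusing the hypothetical pad_to_multiple from above:

x = torch.randn(1, 16, 33, 47)
h = pad_to_multiple(x, (1, 2, 2))            # (1, 16, 34, 48)
out = h                                      # stand-in for the un-patchified model output
out = out[..., :x.shape[-2], :x.shape[-1]]   # drop the padded rows/columns
assert out.shape == x.shape                  # back to (1, 16, 33, 47)

The ellipsis keeps the sketch layout-agnostic; the commit instead slices each leading dimension explicitly to match the model's latent layout.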