Skip to content

Commit d044a24

Browse files
Fix default shift and any latent size for qwen image model. (Comfy-Org#9186)
1 parent 5be6fd0 commit d044a24

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
99
from comfy.ldm.modules.attention import optimized_attention_masked
1010
from comfy.ldm.flux.layers import EmbedND
11-
11+
import comfy.ldm.common_dit
1212

1313
class GELU(nn.Module):
1414
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -364,8 +364,9 @@ def forward(
364364

365365
image_rotary_emb = self.pos_embeds(x, context)
366366

367-
orig_shape = x.shape
368-
hidden_states = x.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
367+
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
368+
orig_shape = hidden_states.shape
369+
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
369370
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
370371
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
371372

@@ -396,4 +397,4 @@ def forward(
396397

397398
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
398399
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
399-
return hidden_states.reshape(orig_shape)
400+
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

comfy/supported_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,7 @@ class QwenImage(supported_models_base.BASE):
12371237

12381238
sampling_settings = {
12391239
"multiplier": 1.0,
1240-
"shift": 2.6,
1240+
"shift": 1.15,
12411241
}
12421242

12431243
memory_usage_factor = 1.8 #TODO

0 commit comments

Comments
 (0)