Skip to content

Commit ad19a06

Browse files
Make SLG nodes work on Qwen Image model. (Comfy-Org#9345)
1 parent 5d65d67 commit ad19a06

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def forward(
356356
context,
357357
attention_mask=None,
358358
guidance: torch.Tensor = None,
359+
transformer_options={},
359360
**kwargs
360361
):
361362
timestep = timesteps
@@ -383,14 +384,26 @@ def forward(
383384
else self.time_text_embed(timestep, guidance, hidden_states)
384385
)
385386

386-
for block in self.transformer_blocks:
387-
encoder_hidden_states, hidden_states = block(
388-
hidden_states=hidden_states,
389-
encoder_hidden_states=encoder_hidden_states,
390-
encoder_hidden_states_mask=encoder_hidden_states_mask,
391-
temb=temb,
392-
image_rotary_emb=image_rotary_emb,
393-
)
387+
patches_replace = transformer_options.get("patches_replace", {})
388+
blocks_replace = patches_replace.get("dit", {})
389+
390+
for i, block in enumerate(self.transformer_blocks):
391+
if ("double_block", i) in blocks_replace:
392+
def block_wrap(args):
393+
out = {}
394+
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
395+
return out
396+
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
397+
hidden_states = out["img"]
398+
encoder_hidden_states = out["txt"]
399+
else:
400+
encoder_hidden_states, hidden_states = block(
401+
hidden_states=hidden_states,
402+
encoder_hidden_states=encoder_hidden_states,
403+
encoder_hidden_states_mask=encoder_hidden_states_mask,
404+
temb=temb,
405+
image_rotary_emb=image_rotary_emb,
406+
)
394407

395408
hidden_states = self.norm_out(hidden_states, temb)
396409
hidden_states = self.proj_out(hidden_states)

0 commit comments

Comments
 (0)