Skip to content

Commit ff57793

Browse files
Support InstantX Qwen controlnet. (Comfy-Org#9488)
1 parent f7bd5e5 commit ff57793

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

comfy/controlnet.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import comfy.cldm.mmdit
3737
import comfy.ldm.hydit.controlnet
3838
import comfy.ldm.flux.controlnet
39+
import comfy.ldm.qwen_image.controlnet
3940
import comfy.cldm.dit_embedder
4041
from typing import TYPE_CHECKING
4142
if TYPE_CHECKING:
@@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}):
582583
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
583584
return control
584585

586+
def load_controlnet_qwen_instantx(sd, model_options={}):
    """Build a ControlNet wrapper around an InstantX Qwen-Image controlnet state dict.

    Args:
        sd: controlnet state dict in the InstantX Qwen layout.
        model_options: optional overrides forwarded to controlnet_config.

    Returns:
        A ControlNet instance ready to be applied to a Qwen-Image model.
    """
    cfg, ops, load_dev, dtype, cast_dtype, offload_dev = controlnet_config(sd, model_options=model_options)

    model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(
        operations=ops, device=offload_dev, dtype=dtype, **cfg.unet_config)
    model = controlnet_load_state_dict(model, sd)

    # NOTE(review): Qwen-Image apparently shares the Wan 2.1 latent space —
    # confirm against the model's latent format elsewhere in the project.
    return ControlNet(
        model,
        compression_ratio=1,
        latent_format=comfy.latent_formats.Wan21(),
        load_device=load_dev,
        manual_cast_dtype=cast_dtype,
        extra_conds=[],
    )
594+
585595
def convert_mistoline(sd):
586596
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
587597

@@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
655665
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
656666
else:
657667
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
668+
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
669+
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
658670
elif "controlnet_x_embedder.weight" in controlnet_data:
659671
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
672+
660673
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
661674
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
662675

comfy/ldm/qwen_image/controlnet.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import torch
2+
import math
3+
4+
from .model import QwenImageTransformer2DModel
5+
6+
7+
class QwenImageControlNetModel(QwenImageTransformer2DModel):
    """InstantX controlnet variant of the Qwen-Image transformer.

    Runs only its own (smaller) stack of transformer blocks and emits one
    projected image-stream residual per block, each repeated enough times to
    cover the 60 double blocks of the full main model.
    """

    def __init__(
        self,
        extra_condition_channels=0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        # final_layer=False: this model only produces intermediate residuals,
        # never a decoded output.
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
        # Block count of the main Qwen-Image model the residuals feed into.
        self.main_model_double = 60

        # One zero-init-style output projection per local transformer block.
        self.controlnet_blocks = torch.nn.ModuleList(
            operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype)
            for _ in range(len(self.transformer_blocks))
        )
        # Embeds the control hint latent (plus any extra condition channels).
        self.controlnet_x_embedder = operations.Linear(
            self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        ref_latents=None,
        hint=None,
        transformer_options={},
        **kwargs
    ):
        """Compute per-block control residuals.

        Args:
            x: noisy image latent (same layout `process_img` expects).
            timesteps: diffusion timestep tensor.
            context: text-encoder embeddings.
            attention_mask: optional mask over the text tokens.
            guidance: optional guidance tensor; scaled by 1000 before embedding.
            ref_latents: unused here (kept for signature compatibility).
            hint: control image latent — assumed same spatial shape as x
                (TODO confirm with callers).

        Returns:
            {"input": tuple of residual tensors, one per main-model block}.
        """
        img, img_ids, _ = self.process_img(x)
        ctrl, _, _ = self.process_img(hint)

        # Text token ids start past the larger image axis, measured in patches.
        txt_start = round(max(
            ((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2,
            ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2,
        ))
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        rope = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1)).squeeze(1).unsqueeze(2).to(x.dtype)
        del txt_ids, img_ids

        # Inject the control hint additively into the image stream.
        img = self.img_in(img) + self.controlnet_x_embedder(ctrl)
        txt = self.txt_in(self.txt_norm(context))

        if guidance is None:
            temb = self.time_text_embed(timesteps, img)
        else:
            temb = self.time_text_embed(timesteps, guidance * 1000, img)

        # Repeat each local residual so the sequence covers every main-model block.
        repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))

        samples = []
        for block, proj in zip(self.transformer_blocks, self.controlnet_blocks):
            txt, img = block(
                hidden_states=img,
                encoder_hidden_states=txt,
                encoder_hidden_states_mask=attention_mask,
                temb=temb,
                image_rotary_emb=rope,
            )
            samples.extend([proj(img)] * repeat)

        return {"input": tuple(samples[:self.main_model_double])}

0 commit comments

Comments
 (0)