@@ -253,7 +253,10 @@ def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
253253 to_concat = []
254254 for c in self .extra_concat_orig :
255255 c = c .to (self .cond_hint .device )
256- c = comfy .utils .common_upscale (c , self .cond_hint .shape [3 ], self .cond_hint .shape [2 ], self .upscale_algorithm , "center" )
256+ c = comfy .utils .common_upscale (c , self .cond_hint .shape [- 1 ], self .cond_hint .shape [- 2 ], self .upscale_algorithm , "center" )
257+ if c .ndim < self .cond_hint .ndim :
258+ c = c .unsqueeze (2 )
259+ c = comfy .utils .repeat_to_batch_size (c , self .cond_hint .shape [2 ], dim = 2 )
257260 to_concat .append (comfy .utils .repeat_to_batch_size (c , self .cond_hint .shape [0 ]))
258261 self .cond_hint = torch .cat ([self .cond_hint ] + to_concat , dim = 1 )
259262
@@ -585,11 +588,18 @@ def load_controlnet_flux_instantx(sd, model_options={}):
585588
586589def load_controlnet_qwen_instantx (sd , model_options = {}):
587590 model_config , operations , load_device , unet_dtype , manual_cast_dtype , offload_device = controlnet_config (sd , model_options = model_options )
588- control_model = comfy .ldm .qwen_image .controlnet .QwenImageControlNetModel (operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
591+ control_latent_channels = sd .get ("controlnet_x_embedder.weight" ).shape [1 ]
592+
593+ extra_condition_channels = 0
594+ concat_mask = False
595+ if control_latent_channels == 68 : #inpaint controlnet
596+ extra_condition_channels = control_latent_channels - 64
597+ concat_mask = True
598+ control_model = comfy .ldm .qwen_image .controlnet .QwenImageControlNetModel (extra_condition_channels = extra_condition_channels , operations = operations , device = offload_device , dtype = unet_dtype , ** model_config .unet_config )
589599 control_model = controlnet_load_state_dict (control_model , sd )
590600 latent_format = comfy .latent_formats .Wan21 ()
591601 extra_conds = []
592- control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
602+ control = ControlNet (control_model , compression_ratio = 1 , latent_format = latent_format , concat_mask = concat_mask , load_device = load_device , manual_cast_dtype = manual_cast_dtype , extra_conds = extra_conds )
593603 return control
594604
595605def convert_mistoline (sd ):
0 commit comments