Skip to content

Commit 560d38f

Browse files
Wan2.2 fun control support. (Comfy-Org#9292)
1 parent e1d4f36 commit 560d38f

File tree

4 files changed

+91
-1
lines changed

4 files changed

+91
-1
lines changed

comfy/ldm/wan/model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ def __init__(self,
391391
cross_attn_norm=True,
392392
eps=1e-6,
393393
flf_pos_embed_token_number=None,
394+
in_dim_ref_conv=None,
394395
image_model=None,
395396
device=None,
396397
dtype=None,
@@ -484,6 +485,11 @@ def __init__(self,
484485
else:
485486
self.img_emb = None
486487

488+
if in_dim_ref_conv is not None:
489+
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
490+
else:
491+
self.ref_conv = None
492+
487493
def forward_orig(
488494
self,
489495
x,
@@ -526,6 +532,13 @@ def forward_orig(
526532
e = e.reshape(t.shape[0], -1, e.shape[-1])
527533
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
528534

535+
full_ref = None
536+
if self.ref_conv is not None:
537+
full_ref = kwargs.get("reference_latent", None)
538+
if full_ref is not None:
539+
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
540+
x = torch.concat((full_ref, x), dim=1)
541+
529542
# context
530543
context = self.text_embedding(context)
531544

@@ -552,6 +565,9 @@ def block_wrap(args):
552565
# head
553566
x = self.head(x, e)
554567

568+
if full_ref is not None:
569+
x = x[:, full_ref.shape[1]:]
570+
555571
# unpatchify
556572
x = self.unpatchify(x, grid_sizes)
557573
return x
@@ -570,6 +586,9 @@ def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tra
570586
x = torch.cat([x, time_dim_concat], dim=2)
571587
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
572588

589+
if self.ref_conv is not None and "reference_latent" in kwargs:
590+
t_len += 1
591+
573592
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
574593
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
575594
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)

comfy/model_base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,11 @@ def concat_cond(self, **kwargs):
11241124
mask = mask.repeat(1, 4, 1, 1, 1)
11251125
mask = utils.resize_to_batch_size(mask, noise.shape[0])
11261126

1127-
return torch.cat((mask, image), dim=1)
1127+
concat_mask_index = kwargs.get("concat_mask_index", 0)
1128+
if concat_mask_index != 0:
1129+
return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
1130+
else:
1131+
return torch.cat((mask, image), dim=1)
11281132

11291133
def extra_conds(self, **kwargs):
11301134
out = super().extra_conds(**kwargs)
@@ -1140,6 +1144,10 @@ def extra_conds(self, **kwargs):
11401144
if time_dim_concat is not None:
11411145
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
11421146

1147+
reference_latents = kwargs.get("reference_latents", None)
1148+
if reference_latents is not None:
1149+
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
1150+
11431151
return out
11441152

11451153

comfy/model_detection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
373373
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
374374
if flf_weight is not None:
375375
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
376+
377+
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
378+
if ref_conv_weight is not None:
379+
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
380+
376381
return dit_config
377382

378383
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

comfy_extras/nodes_wan.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,63 @@ def encode(self, positive, negative, vae, width, height, length, batch_size, sta
103103
out_latent["samples"] = latent
104104
return (positive, negative, out_latent)
105105

106+
class Wan22FunControlToVideo:
107+
@classmethod
108+
def INPUT_TYPES(s):
109+
return {"required": {"positive": ("CONDITIONING", ),
110+
"negative": ("CONDITIONING", ),
111+
"vae": ("VAE", ),
112+
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
113+
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
114+
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
115+
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
116+
},
117+
"optional": {"ref_image": ("IMAGE", ),
118+
"control_video": ("IMAGE", ),
119+
# "start_image": ("IMAGE", ),
120+
}}
121+
122+
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
123+
RETURN_NAMES = ("positive", "negative", "latent")
124+
FUNCTION = "encode"
125+
126+
CATEGORY = "conditioning/video_models"
127+
128+
def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None):
129+
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
130+
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
131+
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
132+
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
133+
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
134+
135+
if start_image is not None:
136+
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
137+
concat_latent_image = vae.encode(start_image[:, :, :, :3])
138+
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
139+
mask[:, :, :start_image.shape[0] + 3] = 0.0
140+
141+
ref_latent = None
142+
if ref_image is not None:
143+
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
144+
ref_latent = vae.encode(ref_image[:, :, :, :3])
145+
146+
if control_video is not None:
147+
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
148+
concat_latent_image = vae.encode(control_video[:, :, :, :3])
149+
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
150+
151+
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
152+
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
153+
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
154+
155+
if ref_latent is not None:
156+
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
157+
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
158+
159+
out_latent = {}
160+
out_latent["samples"] = latent
161+
return (positive, negative, out_latent)
162+
106163
class WanFirstLastFrameToVideo:
107164
@classmethod
108165
def INPUT_TYPES(s):
@@ -733,6 +790,7 @@ def encode(self, vae, width, height, length, batch_size, start_image=None):
733790
"WanTrackToVideo": WanTrackToVideo,
734791
"WanImageToVideo": WanImageToVideo,
735792
"WanFunControlToVideo": WanFunControlToVideo,
793+
"Wan22FunControlToVideo": Wan22FunControlToVideo,
736794
"WanFunInpaintToVideo": WanFunInpaintToVideo,
737795
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
738796
"WanVaceToVideo": WanVaceToVideo,

0 commit comments

Comments
 (0)