Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions comfy/ldm/hunyuan_video/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def forward_orig(
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
control=None,
transformer_options={},
) -> Tensor:
Expand All @@ -238,6 +239,14 @@ def forward_orig(
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
img = torch.cat([ref_latent, img], dim=-2)
ref_latent_ids[..., 0] = -1
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)

if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
Expand Down Expand Up @@ -313,6 +322,8 @@ def block_wrap(args):
img[:, : img_len] += add

img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]

img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)

Expand All @@ -324,7 +335,7 @@ def block_wrap(args):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img

def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
def img_ids(self, x):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
Expand All @@ -334,7 +345,11 @@ def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, g
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out
4 changes: 4 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,10 @@ def extra_conds(self, **kwargs):
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))

ref_latent = kwargs.get("ref_latent", None)
if ref_latent is not None:
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))

return out

def scale_latent_inpaint(self, latent_image, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions comfy_extras/nodes_apg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def INPUT_TYPES(s):
"model": ("MODEL",),
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
}
}
RETURN_TYPES = ("MODEL",)
Expand All @@ -41,7 +41,7 @@ def pre_cfg_function(args):

guidance = cond - uncond

if momentum > 0:
if momentum != 0:
if not torch.is_tensor(running_avg):
running_avg = guidance
else:
Expand Down
6 changes: 4 additions & 2 deletions comfy_extras/nodes_hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def INPUT_TYPES(s):
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
Expand All @@ -101,10 +101,12 @@ def encode(self, positive, vae, width, height, length, batch_size, guidance_type

if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
else:
elif guidance_type == "v2 (replace)":
cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask
elif guidance_type == "custom":
cond = {"ref_latent": concat_latent_image}

positive = node_helpers.conditioning_set_values(positive, cond)

Expand Down
Loading