diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 72af3d5bb1e3..fbd8d4196d14 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -228,6 +228,7 @@ def forward_orig(
         y: Tensor,
         guidance: Tensor = None,
         guiding_frame_index=None,
+        ref_latent=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -238,6 +239,14 @@ def forward_orig(
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
+        if ref_latent is not None:
+            ref_latent_ids = self.img_ids(ref_latent)
+            ref_latent = self.img_in(ref_latent)
+            img = torch.cat([ref_latent, img], dim=-2)
+            ref_latent_ids[..., 0] = -1
+            ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
+            img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
+
         if guiding_frame_index is not None:
             token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
             vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ def block_wrap(args):
                         img[:, : img_len] += add
 
         img = img[:, : img_len]
+        if ref_latent is not None:
+            img = img[:, ref_latent.shape[1]:]
 
         img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)
 
@@ -324,7 +335,7 @@ def block_wrap(args):
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
+    def img_ids(self, x):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, g
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
         img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
-        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+        return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+        bs, c, t, h, w = x.shape
+        img_ids = self.img_ids(x)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
         return out
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 6d27930dc139..047861593639 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -924,6 +924,10 @@ def extra_conds(self, **kwargs):
         if guiding_frame_index is not None:
             out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
 
+        ref_latent = kwargs.get("ref_latent", None)
+        if ref_latent is not None:
+            out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
+
         return out
 
     def scale_latent_inpaint(self, latent_image, **kwargs):
diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py
index 1325985b2e27..25b21b1b8b2b 100644
--- a/comfy_extras/nodes_apg.py
+++ b/comfy_extras/nodes_apg.py
@@ -14,7 +14,7 @@ def INPUT_TYPES(s):
                 "model": ("MODEL",),
                 "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
                 "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
-                "momentum": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
+                "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
             }
         }
     RETURN_TYPES = ("MODEL",)
@@ -41,7 +41,7 @@ def pre_cfg_function(args):
 
             guidance = cond - uncond
 
-            if momentum > 0:
+            if momentum != 0:
                 if not torch.is_tensor(running_avg):
                     running_avg = guidance
                 else:
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index 504010ad034c..d7278e7a7d86 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -77,7 +77,7 @@ def INPUT_TYPES(s):
                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
-                             "guidance_type": (["v1 (concat)", "v2 (replace)"], )
+                             "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
                              },
                 "optional": {"start_image": ("IMAGE", ),
                              }}
@@ -101,10 +101,12 @@ def encode(self, positive, vae, width, height, length, batch_size, guidance_type
 
         if guidance_type == "v1 (concat)":
             cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
-        else:
+        elif guidance_type == "v2 (replace)":
            cond = {'guiding_frame_index': 0}
             latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
             out_latent["noise_mask"] = mask
+        elif guidance_type == "custom":
+            cond = {"ref_latent": concat_latent_image}
 
         positive = node_helpers.conditioning_set_values(positive, cond)
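For readers following the `model.py` hunks: the reference latent is patchified with the same `img_ids` helper as the video, prepended to the video token sequence with positional ids that cannot collide with any real frame (time index -1, width index shifted past the frame width), and cropped back off before `final_layer`. A minimal standalone sketch of that token bookkeeping, using dummy shapes and illustrative names rather than ComfyUI's actual modules:

```python
import torch

# Stand-in sizes for (batch, tokens, channels) after patchify; every name
# below is illustrative, not ComfyUI's.
bs, ref_tokens, vid_tokens, dim = 1, 6, 24, 16
w_len = 8  # video width in patches; ref ids are shifted past it

ref = torch.randn(bs, ref_tokens, dim)    # embedded reference latent
vid = torch.randn(bs, vid_tokens, dim)    # embedded video latent
ref_ids = torch.zeros(bs, ref_tokens, 3)  # (t, h, w) position id per token
vid_ids = torch.zeros(bs, vid_tokens, 3)

# Mirror the forward_orig hunk: prepend the reference tokens and give them
# ids no video token can have (time index -1, width offset past the frame).
ref_ids[..., 0] = -1
ref_ids[..., 2] += w_len
tokens = torch.cat([ref, vid], dim=-2)
ids = torch.cat([ref_ids, vid_ids], dim=-2)

# ... the transformer blocks attend over all tokens using `ids` ...

# Mirror the crop before final_layer: only video tokens produce output.
out = tokens[:, ref.shape[1]:]
assert out.shape == vid.shape
```

The out-of-range ids appear intended to let every video token attend to the reference while keeping it positionally distinct from frame 0, so the model conditions on it without treating it as the first frame.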