Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,66 @@ class HunyuanImage21Refiner(LatentFormat):
latent_dimensions = 3
scale_factor = 1.03682

def process_in(self, latent):
out = latent * self.scale_factor
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out

def process_out(self, latent):
z = latent / self.scale_factor
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
return z

class HunyuanVideo15(LatentFormat):
latent_rgb_factors = [
[ 0.0568, -0.0521, -0.0131],
[ 0.0014, 0.0735, 0.0326],
[ 0.0186, 0.0531, -0.0138],
[-0.0031, 0.0051, 0.0288],
[ 0.0110, 0.0556, 0.0432],
[-0.0041, -0.0023, -0.0485],
[ 0.0530, 0.0413, 0.0253],
[ 0.0283, 0.0251, 0.0339],
[ 0.0277, -0.0372, -0.0093],
[ 0.0393, 0.0944, 0.1131],
[ 0.0020, 0.0251, 0.0037],
[-0.0017, 0.0012, 0.0234],
[ 0.0468, 0.0436, 0.0203],
[ 0.0354, 0.0439, -0.0233],
[ 0.0090, 0.0123, 0.0346],
[ 0.0382, 0.0029, 0.0217],
[ 0.0261, -0.0300, 0.0030],
[-0.0088, -0.0220, -0.0283],
[-0.0272, -0.0121, -0.0363],
[-0.0664, -0.0622, 0.0144],
[ 0.0414, 0.0479, 0.0529],
[ 0.0355, 0.0612, -0.0247],
[ 0.0147, 0.0264, 0.0174],
[ 0.0438, 0.0038, 0.0542],
[ 0.0431, -0.0573, -0.0033],
[-0.0162, -0.0211, -0.0406],
[-0.0487, -0.0295, -0.0393],
[ 0.0005, -0.0109, 0.0253],
[ 0.0296, 0.0591, 0.0353],
[ 0.0119, 0.0181, -0.0306],
[-0.0085, -0.0362, 0.0229],
[ 0.0005, -0.0106, 0.0242]
]

latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682

class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
Expand Down
54 changes: 47 additions & 7 deletions comfy/ldm/hunyuan_video/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention


from dataclasses import dataclass
from einops import repeat

Expand Down Expand Up @@ -42,6 +41,8 @@ class HunyuanVideoParams:
guidance_embed: bool
byt5: bool
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int


class SelfAttentionRef(nn.Module):
Expand Down Expand Up @@ -157,7 +158,10 @@ def forward(
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
if x.dtype == torch.float16:
c = x.float().sum(dim=1) / x.shape[1]
else:
c = x.sum(dim=1) / x.shape[1]

c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
Expand Down Expand Up @@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}

params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
self.use_cond_type_embedding = params.use_cond_type_embedding
self.vision_in_dim = params.vision_in_dim
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
Expand Down Expand Up @@ -266,6 +274,18 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)

# HunyuanVideo 1.5 specific modules
if self.vision_in_dim is not None:
from comfy.ldm.wan.model import MLPProj
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
else:
self.vision_in = None
if self.use_cond_type_embedding:
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
else:
self.cond_type_embedding = None

def forward_orig(
self,
img: Tensor,
Expand All @@ -276,6 +296,7 @@ def forward_orig(
timesteps: Tensor,
y: Tensor = None,
txt_byt5=None,
clip_fea=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
Expand Down Expand Up @@ -331,12 +352,31 @@ def forward_orig(

txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)

if self.cond_type_embedding is not None:
self.cond_type_embedding.to(txt.device)
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
txt = txt + cond_emb.to(txt.dtype)

if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
else:
txt = torch.cat((txt, txt_byt5), dim=1)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)

if clip_fea is not None:
txt_vision_states = self.vision_in(clip_fea)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
txt_vision_states = txt_vision_states + cond_emb
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)

ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)

Expand Down Expand Up @@ -430,20 +470,20 @@ def img_ids_2d(self, x):
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)

def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)

def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out
120 changes: 120 additions & 0 deletions comfy/ldm/hunyuan_video/upsampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
import model_management, model_patcher

class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.block = nn.Sequential(
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)

class SRModel3DV2(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int = 64,
num_blocks: int = 6,
global_residual: bool = False,
):
super().__init__()
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
self.global_residual = bool(global_residual)

def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
y = self.in_conv(x)
for blk in self.blocks:
y = blk(y)
y = self.out_conv(y)
if self.global_residual and (y.shape == residual.shape):
y = y + residual
return y


class Upsampler(nn.Module):
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: tuple[int, ...],
num_res_blocks: int = 2,
):
super().__init__()
self.num_res_blocks = num_res_blocks
self.block_out_channels = block_out_channels
self.z_channels = z_channels

ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)

self.up = nn.ModuleList()

for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_shortcut=False,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
self.up.append(stage)

self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)

def forward(self, z):
"""
Args:
z: (B, C, T, H, W)
target_shape: (H, W)
"""
# z to block_in
repeats = self.block_out_channels[0] // (self.z_channels)
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

# upsampling
for stage in self.up:
for blk in stage.block:
x = blk(x)

out = self.conv_out(F.silu(self.norm_out(x)))
return out

UPSAMPLERS = {
"720p": SRModel3DV2,
"1080p": Upsampler,
}

class HunyuanVideo15SRModel():
def __init__(self, model_type, config):
self.load_device = model_management.vae_device()
offload_device = model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()

self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)

def get_sd(self):
return self.model.state_dict()

def resample_latent(self, latent):
model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device))
Loading
Loading