5 changes: 5 additions & 0 deletions comfy/latent_formats.py
@@ -533,6 +533,11 @@ def __init__(self):
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)

class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
scale_factor = 0.75289

class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
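Note on the new latent format: HunyuanImage21 registers a 64-channel, 2-dimensional latent space with a scale factor of 0.75289. In ComfyUI a LatentFormat's scale_factor is conventionally multiplied in on the way to the model and divided out on the way back; the sketch below only illustrates that convention and is an assumption about the base-class behaviour, not part of the diff.

import torch

class HunyuanImage21LatentFormat:
    # Mirrors the values added in the diff above.
    latent_channels = 64
    latent_dimensions = 2
    scale_factor = 0.75289

    # Approximation of the usual LatentFormat convention (an assumption,
    # not copied from comfy/latent_formats.py).
    def process_in(self, latent):
        return latent * self.scale_factor

    def process_out(self, latent):
        return latent / self.scale_factor

fmt = HunyuanImage21LatentFormat()
latent = torch.randn(1, 64, 128, 128)   # [batch, channels, height, width]
assert torch.allclose(fmt.process_out(fmt.process_in(latent)), latent)
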
102 changes: 86 additions & 16 deletions comfy/ldm/hunyuan_video/model.py
@@ -40,6 +40,7 @@ class HunyuanVideoParams:
patch_size: list
qkv_bias: bool
guidance_embed: bool
byt5: bool


class SelfAttentionRef(nn.Module):
@@ -161,6 +162,30 @@ def forward(
x = self.individual_token_refiner(x, c, mask)
return x


class ByT5Mapper(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
super().__init__()
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
self.use_res = use_res
self.act_fn = nn.GELU()

def forward(self, x):
if self.use_res:
res = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x2 = self.act_fn(x)
x2 = self.fc3(x2)
if self.use_res:
x2 = x2 + res
return x2

class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -185,9 +210,13 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
if params.vec_in_dim is not None:
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.vector_in = None

self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
@@ -215,6 +244,18 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
]
)

if params.byt5:
self.byt5_in = ByT5Mapper(
in_dim=1472,
out_dim=2048,
hidden_dim=2048,
out_dim1=self.hidden_size,
use_res=False,
dtype=dtype, device=device, operations=operations
)
else:
self.byt5_in = None

if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)

@@ -226,7 +267,8 @@ def forward_orig(
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
y: Tensor = None,
txt_byt5=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
@@ -250,13 +292,17 @@

if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
if self.vector_in is not None:
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
else:
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.vector_in is not None:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None

@@ -269,6 +315,12 @@

txt = self.txt_in(txt, timesteps, txt_mask)

if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)

ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)

@@ -328,12 +380,16 @@ def block_wrap(args):

img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)

shape = initial_shape[-3:]
shape = initial_shape[-len(self.patch_size):]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
if img.ndim == 8:
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
else:
img = img.permute(0, 3, 1, 4, 2, 5)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
return img

def img_ids(self, x):
@@ -348,16 +404,30 @@ def img_ids(self, x):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def img_ids_2d(self, x):
bs, c, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)

def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)

def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out
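
Note on the model changes: the video transformer is made reusable for 2D image generation. patch_size now decides whether patch embedding and unpatchify run in 3D or 2D, vector_in becomes optional, and an optional ByT5Mapper projects 1472-dimensional ByT5-small glyph features up to the transformer width so they can be appended to the text token sequence. Below is a minimal shape sketch of that ByT5 path, using plain torch.nn layers instead of ComfyUI's operations wrappers and with invented sequence lengths and hidden size; it is an illustration, not the PR's code.

import torch
import torch.nn as nn

class ByT5MapperSketch(nn.Module):
    # Same layer layout as ByT5Mapper above: LayerNorm -> fc1 -> GELU -> fc2 -> GELU -> fc3.
    def __init__(self, in_dim=1472, hidden_dim=2048, out_dim=2048, out_dim1=3072):
        super().__init__()
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.act_fn = nn.GELU()

    def forward(self, x):
        x = self.act_fn(self.fc1(self.layernorm(x)))
        return self.fc3(self.act_fn(self.fc2(x)))

txt = torch.randn(1, 77, 3072)        # regular text tokens, already at the transformer width (example value)
txt_byt5 = torch.randn(1, 32, 1472)   # hypothetical ByT5-small glyph features
mapper = ByT5MapperSketch()
txt = torch.cat((txt, mapper(txt_byt5)), dim=1)   # extra tokens appended to the sequence
print(txt.shape)                                   # torch.Size([1, 109, 3072])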
136 changes: 136 additions & 0 deletions comfy/ldm/hunyuan_video/vae.py
@@ -0,0 +1,136 @@
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
import comfy.ops
ops = comfy.ops.disable_weight_init


class PixelShuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
self.ratio = (in_dim << 2) // out_dim

def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h >> 1, w >> 1
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)


class PixelUnshuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
self.scale = (out_dim << 2) // in_dim

def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h << 1, w << 1
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
return y + r


class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)

self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()

for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.down.append(stage)

self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)

def forward(self, x):
x = self.conv_in(x)

for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)

x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

b, c, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, h, w).mean(2)

return self.conv_out(F.silu(self.norm_out(x))) + skip


class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks

ch = block_out_channels[0]
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)

self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)

self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()

for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.up.append(stage)

self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)

def forward(self, z):
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)

return self.conv_out(F.silu(self.norm_out(x)))
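
Note on the new single-image VAE: strided resampling is replaced by shuffle-based blocks. PixelShuffle2D is a 3x3 convolution followed by a space-to-depth rearrangement plus a parameter-free, channel-averaged shortcut; PixelUnshuffle2D is the depth-to-space mirror with a channel-repeated shortcut. A rough shape check is sketched below, assuming the module is importable as comfy.ldm.hunyuan_video.vae and that the channel counts satisfy the divisibility the shortcut paths require; the sizes are invented.

import torch
from comfy.ldm.hunyuan_video.vae import PixelShuffle2D, PixelUnshuffle2D

x = torch.randn(1, 128, 64, 64)                   # invented feature map
down = PixelShuffle2D(in_dim=128, out_dim=256)    # halves H and W
up = PixelUnshuffle2D(in_dim=256, out_dim=128)    # doubles H and W

y = down(x)
print(y.shape)        # torch.Size([1, 256, 32, 32])
print(up(y).shape)    # torch.Size([1, 128, 64, 64])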
24 changes: 24 additions & 0 deletions comfy/model_base.py
@@ -1408,3 +1408,27 @@ def extra_conds_shapes(self, **kwargs):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out

class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)

guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))

return out
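
Note on the conditioning wiring: HunyuanImage21.extra_conds packages the sampler-side inputs so their keys line up with the updated forward signature: the main text-encoder tokens become c_crossattn (consumed as the model's context), the optional ByT5-small conditioning travels as txt_byt5, and guidance defaults to 6.0. A simplified mirror of that logic, with the CONDRegular wrappers omitted and invented tensors standing in for encoder outputs:

import torch

def build_extra_conds(cross_attn, conditioning_byt5small=None, guidance=6.0):
    # Simplified sketch of HunyuanImage21.extra_conds above (not the real code path).
    out = {"c_crossattn": cross_attn}                    # consumed as `context`
    if conditioning_byt5small is not None:
        out["txt_byt5"] = conditioning_byt5small         # matches forward(..., txt_byt5=None)
    if guidance is not None:
        out["guidance"] = torch.FloatTensor([guidance])  # embedded only when guidance_embed is set
    return out

conds = build_extra_conds(torch.randn(1, 77, 4096), torch.randn(1, 32, 1472))
print({k: tuple(v.shape) for k, v in conds.items()})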
28 changes: 20 additions & 8 deletions comfy/model_detection.py
@@ -136,20 +136,32 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
dit_config["vec_in_dim"] = 768
dit_config["axes_dim"] = [16, 56, 56]
else:
dit_config["vec_in_dim"] = None
dit_config["axes_dim"] = [64, 64]

dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
dit_config["byt5"] = True
else:
dit_config["byt5"] = False

guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
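
Note on detection: the Hunyuan Video / HunyuanImage configuration is now read off tensor shapes instead of being hard-coded. patch_size comes from the img_in.proj.weight kernel dims (two entries for the image model, three for video), out_channels from the final layer, hidden_size and num_heads (at the head width of 128 implied by the code) from the same projection, and vec_in_dim/axes_dim and the byt5 flag from whether the corresponding weights exist. A toy run of that logic on an invented state dict with 2D-model-like shapes:

import math
import torch

# Invented shapes for illustration only; real checkpoints define these.
state_dict = {
    "img_in.proj.weight": torch.zeros(3584, 64, 2, 2),        # hidden=3584, in_channels=64, 2x2 patch
    "final_layer.linear.weight": torch.zeros(64 * 2 * 2, 3584),
    "txt_in.input_embedder.weight": torch.zeros(3584, 3584),
}

in_w = state_dict["img_in.proj.weight"]
out_w = state_dict["final_layer.linear.weight"]
patch_size = list(in_w.shape[2:])                      # [2, 2] -> 2D image model
out_channels = out_w.shape[0] // math.prod(patch_size) # 64
hidden_size = in_w.shape[0]                            # 3584
num_heads = hidden_size // 128                         # 28
context_in_dim = state_dict["txt_in.input_embedder.weight"].shape[1]

print(patch_size, out_channels, hidden_size, num_heads, context_in_dim)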