Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions comfy/ldm/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,60 @@ def forward(self, c, x, **kwargs):
return c_skip, c


class WanCamAdapter(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
super(WanCamAdapter, self).__init__()

# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)

# Convolution: reduce spatial dimensions by a factor
# of 2 (without overlap)
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

# Residual blocks for feature extraction
self.residual_blocks = nn.Sequential(
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
)

def forward(self, x):
# Reshape to merge the frame dimension into batch
bs, c, f, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)

# Pixel Unshuffle operation
x_unshuffled = self.pixel_unshuffle(x)

# Convolution operation
x_conv = self.conv(x_unshuffled)

# Feature extraction with residual blocks
out = self.residual_blocks(x_conv)

# Reshape to restore original bf dimension
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))

# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
out = out.permute(0, 2, 1, 3, 4)

return out


class WanCamResidualBlock(nn.Module):
def __init__(self, dim, operation_settings={}):
super(WanCamResidualBlock, self).__init__()
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.relu = nn.ReLU(inplace=True)
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

def forward(self, x):
residual = x
out = self.relu(self.conv1(x))
out = self.conv2(out)
out += residual
return out


class Head(nn.Module):

def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
Expand Down Expand Up @@ -637,3 +691,92 @@ def block_wrap(args):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

class CameraWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""

def __init__(self,
model_type='camera',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
in_dim_control_adapter=24,
device=None,
dtype=None,
operations=None,
):

super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}

self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)


def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
camera_conditions = None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)

# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))

# context
context = self.text_embedding(context)

context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]

patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

# head
x = self.head(x, e)

# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
11 changes: 11 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,17 @@ def extra_conds(self, **kwargs):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out

class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
self.image_to_video = image_to_video

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
camera_conditions = kwargs.get("camera_conditions", None)
if camera_conditions is not None:
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out

class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
Expand Down
2 changes: 2 additions & 0 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
Expand Down
12 changes: 11 additions & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,16 @@ def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out

class WAN21_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera",
"in_dim": 32,
}

def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
Expand Down Expand Up @@ -1129,6 +1139,6 @@ def get_model(self, state_dict, prefix="", device=None):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]

models += [SVD_img2vid]
Loading
Loading