Skip to content

Commit c820ef9

Browse files
George0726 and Qirui Sun authored
Add Wan-FUN Camera Control models and Add WanCameraImageToVideo node (Comfy-Org#8013)
* support wan camera models
* fix by ruff check
* change camera_condition type; make camera_condition optional
* support camera trajectory nodes
* fix camera direction

---------

Co-authored-by: Qirui Sun <sunqr0667@126.com>
1 parent 6a2e4bb commit c820ef9

File tree

6 files changed

+431
-1
lines changed

6 files changed

+431
-1
lines changed

comfy/ldm/wan/model.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,60 @@ def forward(self, c, x, **kwargs):
247247
return c_skip, c
248248

249249

250+
class WanCamAdapter(nn.Module):
    """Adapter that embeds camera-control condition tensors into the WanModel
    patch-embedding space.

    The condition tensor is spatially downscaled by a factor of 8 with
    PixelUnshuffle (folding each 8x8 neighborhood into channels), projected to
    ``out_dim`` channels with a non-overlapping convolution, then refined by a
    stack of residual blocks.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
        """
        Args:
            in_dim: channel count of the raw camera condition tensor.
            out_dim: channel count of the produced embedding (the model dim).
            kernel_size: spatial kernel of the projection convolution.
            stride: spatial stride of the projection convolution.
            num_residual_blocks: number of WanCamResidualBlock refiners.
            operation_settings: dict carrying "operations", "device", "dtype"
                (existing convention in this file for weight-casting ops).
        """
        # Py3 zero-argument form; original used super(WanCamAdapter, self).
        super().__init__()

        # Pixel Unshuffle: spatial dims shrink by 8, channels grow by 64.
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)

        # Non-overlapping convolution (callers pass stride == kernel_size, the
        # spatial patch size) that projects the unshuffled channels to out_dim.
        self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        # Residual blocks for feature extraction.
        self.residual_blocks = nn.Sequential(
            *[WanCamResidualBlock(out_dim, operation_settings=operation_settings) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        """Embed a camera condition tensor.

        Args:
            x: tensor of shape (batch, channels, frames, height, width).

        Returns:
            Tensor of shape (batch, out_dim, frames, h', w'), with spatial dims
            reduced by the unshuffle factor times the conv stride.
        """
        # Merge the frame dimension into the batch so the 2D ops below process
        # every frame independently.
        bs, c, f, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)

        # Pixel Unshuffle operation.
        x_unshuffled = self.pixel_unshuffle(x)

        # Projection convolution.
        x_conv = self.conv(x_unshuffled)

        # Feature extraction with residual blocks.
        out = self.residual_blocks(x_conv)

        # Split batch back into (batch, frames, ...) and move channels back in
        # front of frames to restore the (b, c, f, h, w) layout.
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
        out = out.permute(0, 2, 1, 3, 4)

        return out
289+
class WanCamResidualBlock(nn.Module):
    """Two-convolution residual block: conv -> ReLU -> conv, plus identity skip.

    Spatial size and channel count are preserved, so the skip connection always
    shape-matches.
    """

    def __init__(self, dim, operation_settings={}):
        """
        Args:
            dim: channel count, preserved through the block.
            operation_settings: dict carrying "operations", "device", "dtype"
                (existing convention in this file for weight-casting ops).
        """
        # Py3 zero-argument form; original used super(WanCamResidualBlock, self).
        super().__init__()
        # Hoist the repeated dict lookups used by both convolutions.
        ops = operation_settings.get("operations")
        device = operation_settings.get("device")
        dtype = operation_settings.get("dtype")
        # 3x3 convs with padding=1 keep the spatial size unchanged.
        self.conv1 = ops.Conv2d(dim, dim, kernel_size=3, padding=1, device=device, dtype=dtype)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = ops.Conv2d(dim, dim, kernel_size=3, padding=1, device=device, dtype=dtype)

    def forward(self, x):
        """Return ``conv2(relu(conv1(x))) + x``; shape is preserved."""
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        # In-place add onto the conv2 output (a fresh tensor), not onto x.
        out += residual
        return out
302+
303+
250304
class Head(nn.Module):
251305

252306
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -637,3 +691,92 @@ def block_wrap(args):
637691
# unpatchify
638692
x = self.unpatchify(x, grid_sizes)
639693
return x
694+
695+
class CameraWanModel(WanModel):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.

    Camera-control variant: built as the 'i2v' WanModel plus a WanCamAdapter
    (``self.control_adapter``) that embeds camera condition tensors and adds
    them to the patch-embedding tokens in forward_orig.
    """

    def __init__(self,
                 model_type='camera',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 flf_pos_embed_token_number=None,
                 image_model=None,
                 in_dim_control_adapter=24,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):

        # NOTE: model_type is accepted for config compatibility, but the parent
        # is always constructed as the 'i2v' variant.
        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        # Adapter for camera conditions; kernel/stride reuse the spatial part
        # of patch_size so the adapter output aligns with the patch grid.
        self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)


    def forward_orig(
        self,
        x,
        t,
        context,
        clip_fea=None,
        freqs=None,
        camera_conditions = None,
        transformer_options={},
        **kwargs,
    ):
        """Run the transformer backbone.

        Mirrors WanModel.forward_orig, with one addition: when
        ``camera_conditions`` is given, it is embedded by ``control_adapter``
        and added to the patch-embedding tokens before the transformer blocks.

        Args:
            x: input latent tensor (passed through self.patch_embedding).
            t: diffusion timestep tensor.
            context: text conditioning tokens.
            clip_fea: optional CLIP image features, prepended to the context
                via self.img_emb when that projection exists.
            freqs: rotary/positional frequencies forwarded to each block.
            camera_conditions: optional camera condition tensor
                # assumes a shape accepted by WanCamAdapter.forward,
                # i.e. (batch, channels, frames, height, width) — TODO confirm
            transformer_options: dict; "patches_replace"/"dit" entries let
                callers wrap individual transformer blocks.
        """
        # embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        if self.control_adapter is not None and camera_conditions is not None:
            # Inject camera control by adding the adapter embedding to the
            # patch tokens (shapes must align with the patch grid).
            x_camera = self.control_adapter(camera_conditions).to(x.dtype)
            x = x + x_camera
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # context
        context = self.text_embedding(context)

        context_img_len = None
        if clip_fea is not None:
            if self.img_emb is not None:
                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
                context = torch.concat([context_clip, context], dim=1)
            context_img_len = clip_fea.shape[-2]

        # Transformer blocks; callers may replace individual blocks through
        # transformer_options["patches_replace"]["dit"].
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
                    return out
                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return x

comfy/model_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,17 @@ def extra_conds(self, **kwargs):
10791079
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
10801080
return out
10811081

1082+
class WAN21_Camera(WAN21):
    """WAN21 variant that swaps the diffusion backbone for CameraWanModel and
    forwards camera conditioning to it."""

    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        # NOTE: super(WAN21, self) deliberately skips WAN21.__init__ so the
        # grandparent constructor runs directly with the CameraWanModel
        # backbone instead of WAN21's default unet_model.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
        self.image_to_video = image_to_video

    def extra_conds(self, **kwargs):
        """Extend WAN21.extra_conds with the optional 'camera_conditions'
        tensor, wrapped as a CONDRegular for the sampler."""
        out = super().extra_conds(**kwargs)
        camera_conditions = kwargs.get("camera_conditions", None)
        if camera_conditions is not None:
            out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
        return out
10821093

10831094
class Hunyuan3Dv2(BaseModel):
10841095
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

comfy/supported_models.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,16 @@ def get_model(self, state_dict, prefix="", device=None):
992992
out = model_base.WAN21(self, image_to_video=False, device=device)
993993
return out
994994

995+
class WAN21_Camera(WAN21_T2V):
    """Detection config for Wan 2.1 Fun camera-control checkpoints."""

    unet_config = {
        "image_model": "wan2.1",
        # NOTE(review): camera checkpoints report model_type "i2v"; in_dim 32
        # appears to be what distinguishes them from other i2v variants —
        # confirm against the sibling WAN21 configs in this file.
        "model_type": "i2v",
        "in_dim": 32,
    }

    def get_model(self, state_dict, prefix="", device=None):
        # Camera models are not image_to_video despite the "i2v" model_type key.
        out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
        return out
9951005
class WAN21_Vace(WAN21_T2V):
9961006
unet_config = {
9971007
"image_model": "wan2.1",
@@ -1129,6 +1139,6 @@ def get_model(self, state_dict, prefix="", device=None):
11291139
def clip_target(self, state_dict={}):
11301140
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
11311141

1132-
# Registry of supported model configs; detection walks this list, so each new
# config (here WAN21_Camera) must be registered to be recognized.
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
11331143

11341144
models += [SVD_img2vid]

0 commit comments

Comments
 (0)