diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 3289a840e2b1..c2d54e91750d 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers: - [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) - [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) - [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers) +- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers) > [!TIP] > Click on the Wan models in the right sidebar for more examples of video generation. @@ -249,6 +250,82 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color. + + + +### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication + +[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team. + +*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.* + +The example below demonstrates how to use the Wan-Animate pipeline to generate a video using a text description, a starting frame, a pose video, and a face video (optionally background video and mask video) in "animation" or "replacement" mode. + + + + +```python +import numpy as np +import torch +import torchvision.transforms.functional as TF +from diffusers import AutoencoderKLWan, WanAnimatePipeline +from diffusers.utils import export_to_video, load_image, load_video +from transformers import CLIPVisionModel + + +model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" +image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = WanAnimatePipeline.from_pretrained( + model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +# Preprocessing: The input video should be preprocessed into several materials before be feed into the inference process. +# TODO: Diffusersify the preprocessing process: !python wan/modules/animate/preprocess/preprocess_data.py + + +image = load_image("preprocessed_results/astronaut.jpg") +pose_video = load_video("preprocessed_results/pose_video.mp4") +face_video = load_video("preprocessed_results/face_video.mp4") + +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +def center_crop_resize(image, height, width): + # Calculate resize ratio to match first frame dimensions + resize_ratio = max(width / image.width, height / image.height) + + # Resize the image + width = round(image.width * resize_ratio) + height = round(image.height * resize_ratio) + size = [width, height] + image = TF.center_crop(image, size) + + return image, height, width + +image, height, width = aspect_ratio_resize(image, pipe) + +prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." + +#guide_scale (`float` or tuple[`float`], *optional*, defaults 1.0): +# Classifier-free guidance scale. We only use it for expression control. +# In most cases, it's not necessary and faster generation can be achieved without it. +# When expression adjustments are needed, you may consider using this feature. +output = pipe( + image=image, pose_video=pose_video, face_video=face_video, prompt=prompt, height=height, width=width, guidance_scale=1.0 +).frames[0] +export_to_video(output, "output.mp4", fps=16) +``` + + + + ## Notes - Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`]. @@ -359,6 +436,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip - all - __call__ +## WanAnimatePipeline + +[[autodoc]] WanAnimatePipeline + - all + - __call__ + ## WanPipelineOutput [[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput \ No newline at end of file diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 39a364b07d78..e357826995c8 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -1,4 +1,5 @@ import argparse +import math import pathlib from typing import Any, Dict, Tuple @@ -6,11 +7,21 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import load_file -from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel +from transformers import ( + AutoProcessor, + AutoTokenizer, + CLIPImageProcessor, + CLIPVisionConfig, + CLIPVisionModel, + CLIPVisionModelWithProjection, + UMT5EncoderModel, +) from diffusers import ( AutoencoderKLWan, UniPCMultistepScheduler, + WanAnimatePipeline, + WanAnimateTransformer3DModel, WanImageToVideoPipeline, WanPipeline, WanTransformer3DModel, @@ -105,8 +116,114 @@ "after_proj": "proj_out", } +ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "cross_attn.k_img": "attn2.to_k_img", + "cross_attn.v_img": "attn2.to_v_img", + "cross_attn.norm_k_img": "attn2.norm_k_img", + # After cross_attn -> attn2 rename, we need to rename the img keys + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", + # Motion encoder mappings + "motion_encoder.enc.net_app.convs": "condition_embedder.motion_embedder.convs", + "motion_encoder.enc.fc": "condition_embedder.motion_embedder.linears", + "motion_encoder.dec.direction.weight": "condition_embedder.motion_embedder.motion_synthesis_weight", + # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten + "face_encoder.conv1_local.conv": "condition_embedder.face_embedder.conv1_local", + "face_encoder.conv2.conv": "condition_embedder.face_embedder.conv2", + "face_encoder.conv3.conv": "condition_embedder.face_embedder.conv3", + "face_encoder.out_proj": "condition_embedder.face_embedder.out_proj", + "face_encoder.norm1": "condition_embedder.face_embedder.norm1", + "face_encoder.norm2": "condition_embedder.face_embedder.norm2", + "face_encoder.norm3": "condition_embedder.face_embedder.norm3", + "face_encoder.padding_tokens": "condition_embedder.face_embedder.padding_tokens", + # Face adapter mappings + "face_adapter.fuser_blocks": "face_adapter", +} + + +def convert_equal_linear_weight(key: str, state_dict: Dict[str, Any]) -> None: + """ + Convert EqualLinear weights to standard Linear weights by applying the scale factor. + EqualLinear uses: F.linear(input, self.weight * self.scale, bias=self.bias) + where scale = (1 / sqrt(in_dim)) + """ + if ".weight" not in key: + return + + in_dim = state_dict[key].shape[1] + scale = 1.0 / math.sqrt(in_dim) + state_dict[key] = state_dict[key] * scale + + +def convert_equal_conv2d_weight(key: str, state_dict: Dict[str, Any]) -> None: + """ + Convert EqualConv2d weights to standard Conv2d weights by applying the scale factor. + EqualConv2d uses: F.conv2d(input, self.weight * self.scale, bias=self.bias, ...) + where scale = 1 / sqrt(in_channel * kernel_size^2) + """ + if ".weight" not in key or len(state_dict[key].shape) != 4: + return + + out_channel, in_channel, kernel_size, kernel_size = state_dict[key].shape + scale = 1.0 / math.sqrt(in_channel * kernel_size**2) + state_dict[key] = state_dict[key] * scale + + +def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any]) -> None: + """ + Convert all motion encoder weights for Animate model. + This handles both EqualLinear (in linears) and EqualConv2d (in convs). + + In the original model: + - All Linear layers in fc use EqualLinear + - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately) + - Blur kernels are stored as buffers in Sequential modules + - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)] + + Conversion strategy: + 1. Drop .kernel buffers (blur kernels) + 2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu) + 3. Scale EqualLinear and EqualConv2d weights + """ + + TRANSFORMER_SPECIAL_KEYS_REMAP = {} VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} +ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {"condition_embedder.motion_embedder": convert_animate_motion_encoder_weights} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: @@ -364,6 +481,31 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-Animate-14B": + config = { + "model_id": "Wan-AI/Wan2.2-Animate-14B", + "diffusers_config": { + "image_dim": 1280, + "added_kv_proj_dim": 5120, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 36, + "motion_encoder_dim": 512, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": (1, 2, 2), + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + "rope_max_seq_len": 1024, + "pos_embed_seq_len": 257 * 2, + }, + } + RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP return config, RENAME_DICT, SPECIAL_KEYS_REMAP @@ -380,10 +522,12 @@ def convert_transformer(model_type: str, stage: str = None): original_state_dict = load_sharded_safetensors(model_dir) with init_empty_weights(): - if "VACE" not in model_type: - transformer = WanTransformer3DModel.from_config(diffusers_config) - else: + if "Animate" in model_type: + transformer = WanAnimateTransformer3DModel.from_config(diffusers_config) + elif "VACE" in model_type: transformer = WanVACETransformer3DModel.from_config(diffusers_config) + else: + transformer = WanTransformer3DModel.from_config(diffusers_config) for key in list(original_state_dict.keys()): new_key = key[:] @@ -397,7 +541,24 @@ def convert_transformer(model_type: str, stage: str = None): continue handler_fn_inplace(key, original_state_dict) + # For Animate model, add blur_conv weights from the initialized model + # These are procedurally generated in the diffusers ConvLayer and not present in original checkpoint + if "Animate" in model_type: + # Create a temporary model on CPU to get the blur_conv weights + with torch.device("cpu"): + temp_transformer = WanAnimateTransformer3DModel.from_config(diffusers_config) + temp_model_state = temp_transformer.state_dict() + for key in temp_model_state.keys(): + if "blur_conv.weight" in key and "motion_embedder" in key: + original_state_dict[key] = temp_model_state[key] + del temp_transformer + + # Load state dict into the meta model, which will materialize the tensors transformer.load_state_dict(original_state_dict, strict=True, assign=True) + + # Move to CPU to ensure all tensors are materialized + transformer = transformer.to("cpu") + return transformer @@ -908,6 +1069,163 @@ def convert_vae_22(): return vae +def convert_openclip_xlm_roberta_vit_to_clip_vision_model(): + """ + Convert OpenCLIP XLM-RoBERTa-CLIP vision encoder to HuggingFace CLIPVisionModel format. + + The original checkpoint contains a multimodal XLM-RoBERTa-CLIP model with: + - Vision encoder: ViT-Huge/14 (1280 dim, 32 layers, 16 heads, patch_size=14) + - Text encoder: XLM-RoBERTa-Large (not used in Wan2.2-Animate) + + We extract only the vision encoder and convert it to CLIPVisionModel format. + + IMPORTANT: The original uses use_31_block=True (returns features from first 31 blocks only). + We convert only the first 31 layers to match this behavior exactly. + """ + # Download the OpenCLIP checkpoint + checkpoint_path = hf_hub_download( + "Wan-AI/Wan2.2-Animate-14B", "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" + ) + + # Load the checkpoint + openclip_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + # Create mapping from OpenCLIP vision encoder to CLIPVisionModel + # OpenCLIP uses "visual." prefix, we need to map to CLIPVisionModel structure + clip_vision_state_dict = {} + + # Mapping rules: + # visual.patch_embedding.weight -> vision_model.embeddings.patch_embedding.weight + # visual.patch_embedding.bias -> vision_model.embeddings.patch_embedding.bias + # visual.cls_embedding -> vision_model.embeddings.class_embedding + # visual.pos_embedding -> vision_model.embeddings.position_embedding.weight + # visual.transformer.{i}.norm1.weight -> vision_model.encoder.layers.{i}.layer_norm1.weight + # visual.transformer.{i}.norm1.bias -> vision_model.encoder.layers.{i}.layer_norm1.bias + # visual.transformer.{i}.attn.to_qkv.weight -> split into to_q, to_k, to_v + # visual.transformer.{i}.attn.proj.weight -> vision_model.encoder.layers.{i}.self_attn.out_proj.weight + # visual.transformer.{i}.norm2.weight -> vision_model.encoder.layers.{i}.layer_norm2.weight + # visual.transformer.{i}.mlp.0.weight -> vision_model.encoder.layers.{i}.mlp.fc1.weight + # visual.transformer.{i}.mlp.2.weight -> vision_model.encoder.layers.{i}.mlp.fc2.weight + # visual.pre_norm -> vision_model.pre_layrnorm (if exists) + # visual.post_norm -> vision_model.post_layernorm (if exists) + + for key, value in openclip_state_dict.items(): + if not key.startswith("visual."): + # Skip text encoder and other components + continue + + # Remove "visual." prefix + new_key = key[7:] # Remove "visual." + + # Embeddings + if new_key == "patch_embedding.weight": + clip_vision_state_dict["vision_model.embeddings.patch_embedding.weight"] = value + elif new_key == "patch_embedding.bias": + clip_vision_state_dict["vision_model.embeddings.patch_embedding.bias"] = value + elif new_key == "cls_embedding": + # Remove extra batch dimension: [1, 1, 1280] -> [1280] + clip_vision_state_dict["vision_model.embeddings.class_embedding"] = value.squeeze() + elif new_key == "pos_embedding": + # Remove extra batch dimension: [1, 257, 1280] -> [257, 1280] + clip_vision_state_dict["vision_model.embeddings.position_embedding.weight"] = value.squeeze(0) + + # Pre-norm (if exists) + elif new_key == "pre_norm.weight": + clip_vision_state_dict["vision_model.pre_layrnorm.weight"] = value + elif new_key == "pre_norm.bias": + clip_vision_state_dict["vision_model.pre_layrnorm.bias"] = value + + # Post-norm - final layer norm after transformer blocks + elif new_key == "post_norm.weight": + clip_vision_state_dict["vision_model.post_layernorm.weight"] = value + elif new_key == "post_norm.bias": + clip_vision_state_dict["vision_model.post_layernorm.bias"] = value + + # Transformer layers (only first 31 layers, skip layer 31 which is index 31) + elif new_key.startswith("transformer."): + parts = new_key.split(".") + if len(parts) >= 3: + layer_idx = int(parts[1]) + + # Skip the 32nd layer (index 31) to match use_31_block=True + if layer_idx >= 31: + continue + + component = ".".join(parts[2:]) + + # Layer norm 1 + if component == "norm1.weight": + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.layer_norm1.weight"] = value + elif component == "norm1.bias": + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.layer_norm1.bias"] = value + + # Attention - QKV split + elif component == "attn.to_qkv.weight": + # Split QKV into separate Q, K, V + qkv = value + q, k, v = qkv.chunk(3, dim=0) + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.q_proj.weight"] = q + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.k_proj.weight"] = k + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.v_proj.weight"] = v + elif component == "attn.to_qkv.bias": + # Split QKV bias + qkv_bias = value + q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.q_proj.bias"] = q_bias + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.k_proj.bias"] = k_bias + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.v_proj.bias"] = v_bias + + # Attention output projection + elif component == "attn.proj.weight": + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.out_proj.weight"] = ( + value + ) + elif component == "attn.proj.bias": + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.self_attn.out_proj.bias"] = value + + # Layer norm 2 + elif component == "norm2.weight": + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.layer_norm2.weight"] = value + elif component == "norm2.bias": + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.layer_norm2.bias"] = value + + # MLP + elif component.startswith("mlp.0."): + # First linear layer + mlp_component = component[6:] # Remove "mlp.0." + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.mlp.fc1.{mlp_component}"] = value + elif component.startswith("mlp.2."): + # Second linear layer (after activation) + mlp_component = component[6:] # Remove "mlp.2." + clip_vision_state_dict[f"vision_model.encoder.layers.{layer_idx}.mlp.fc2.{mlp_component}"] = value + + # Create CLIPVisionModel with matching config + # Use 31 layers to match the original use_31_block=True behavior + config = CLIPVisionConfig( + hidden_size=1280, + intermediate_size=5120, # 1280 * 4 (mlp_ratio) + num_hidden_layers=31, # Only first 31 layers, matching use_31_block=True + num_attention_heads=16, + image_size=224, + patch_size=14, + hidden_act="gelu", + layer_norm_eps=1e-5, + attention_dropout=0.0, + projection_dim=1024, # embed_dim from original config + ) + + with init_empty_weights(): + vision_model = CLIPVisionModel(config) + + # Load state dict into the meta model, which will materialize the tensors + vision_model.load_state_dict(clip_vision_state_dict, strict=True, assign=True) + + # Move to CPU to ensure all tensors are materialized + vision_model = vision_model.to("cpu") + + return vision_model + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--model_type", type=str, default=None) @@ -926,7 +1244,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if "Wan2.2" in args.model_type and "TI2V" not in args.model_type: + if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type: transformer = convert_transformer(args.model_type, stage="high_noise_model") transformer_2 = convert_transformer(args.model_type, stage="low_noise_model") else: @@ -942,7 +1260,7 @@ def get_args(): tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") if "FLF2V" in args.model_type: flow_shift = 16.0 - elif "TI2V" in args.model_type: + elif "TI2V" in args.model_type or "Animate" in args.model_type: flow_shift = 5.0 else: flow_shift = 3.0 @@ -954,6 +1272,8 @@ def get_args(): if args.dtype != "none": dtype = DTYPE_MAPPING[args.dtype] transformer.to(dtype) + if transformer_2 is not None: + transformer_2.to(dtype) if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type: pipe = WanImageToVideoPipeline( @@ -1016,6 +1336,34 @@ def get_args(): vae=vae, scheduler=scheduler, ) + elif "Animate" in args.model_type: + # Convert OpenCLIP XLM-RoBERTa-CLIP vision encoder to CLIPVisionModel + print("Converting XLM-RoBERTa-CLIP vision encoder from OpenCLIP checkpoint...") + image_encoder = convert_openclip_xlm_roberta_vit_to_clip_vision_model() + + # Create image processor for ViT-Huge/14 with 224x224 images + image_processor = CLIPImageProcessor( + size={"shortest_edge": 224}, + crop_size={"height": 224, "width": 224}, + do_center_crop=True, + do_normalize=True, + do_rescale=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, # PIL.Image.BICUBIC + rescale_factor=0.00392156862745098, # 1/255 + ) + + pipe = WanAnimatePipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) else: pipe = WanPipeline( transformer=transformer, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index aa500b149441..a0f4e9fc2075 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -259,6 +259,7 @@ "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", + "WanAnimateTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", "attention_backend", @@ -620,6 +621,7 @@ "VisualClozeGenerationPipeline", "VisualClozePipeline", "VQDiffusionPipeline", + "WanAnimatePipeline", "WanImageToVideoPipeline", "WanPipeline", "WanVACEPipeline", @@ -952,6 +954,7 @@ UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, + WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, attention_backend, @@ -1283,6 +1286,7 @@ VisualClozeGenerationPipeline, VisualClozePipeline, VQDiffusionPipeline, + WanAnimatePipeline, WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 8d029bf5d31c..df5465fe3c0d 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -101,6 +101,7 @@ _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] + _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] @@ -200,6 +201,7 @@ T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, + WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6b80ea6c82a5..0632a1ab4093 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -37,4 +37,5 @@ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel + from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py new file mode 100644 index 000000000000..704b6eab7672 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_wan_animate.py @@ -0,0 +1,611 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import AttentionMixin +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin, get_parameter_dtype +from ..normalization import FP32LayerNorm +from .transformer_wan import ( + WanImageEmbedding, + WanRotaryPosEmbed, + WanTransformerBlock, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ConvLayer(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: int, + downsample: bool = False, + bias: bool = True, + activate: bool = True, + ): + super().__init__() + + self.downsample = downsample + self.activate = activate + + if activate: + self.act = nn.LeakyReLU(0.2) + self.bias_leaky_relu = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + + if downsample: + factor = 2 + blur_kernel = (1, 3, 3, 1) + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + # Create blur kernel + blur_kernel_tensor = torch.tensor(blur_kernel, dtype=torch.float32) + blur_kernel_2d = blur_kernel_tensor[None, :] * blur_kernel_tensor[:, None] + blur_kernel_2d /= blur_kernel_2d.sum() + + self.blur_conv = nn.Conv2d( + in_channel, + in_channel, + blur_kernel_2d.shape[0], + padding=(pad0, pad1), + groups=in_channel, + bias=False, + ) + + # Set the kernel weights + with torch.no_grad(): + # Expand kernel for groups + kernel_expanded = blur_kernel_2d.unsqueeze(0).unsqueeze(0).expand(in_channel, 1, -1, -1) + self.blur_conv.weight.copy_(kernel_expanded) + + stride = 2 + padding = 0 + else: + stride = 1 + padding = kernel_size // 2 + + self.conv2d = nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, bias=bias and not activate) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.downsample: + input = self.blur_conv(input) + + input = self.conv2d(input) + + if self.activate: + input = self.act(input + self.bias_leaky_relu) * 2**0.5 + + return input + + +class ResBlock(nn.Module): + def __init__(self, in_channel: int, out_channel: int): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class WanAnimateMotionEmbedder(nn.Module): + def __init__(self, size: int = 512, style_dim: int = 512, motion_dim: int = 20): + super().__init__() + + # Appearance encoder: conv layers + channels = {4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64, 512: 32, 1024: 16} + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(nn.Conv2d(in_channel, style_dim, 4, padding=0, bias=False)) + + # Motion encoder: linear layers + linears = [] + for _ in range(4): + linears.append(nn.Linear(style_dim, style_dim)) + linears.append(nn.Linear(style_dim, motion_dim)) + self.linears = nn.Sequential(*linears) + + self.motion_synthesis_weight = nn.Parameter(torch.randn(512, 20)) + + def forward(self, face_image: torch.Tensor) -> torch.Tensor: + # Appearance encoding through convs + for conv in self.convs: + face_image = conv(face_image) + face_image = face_image.squeeze(-1).squeeze(-1) + + # Motion feature extraction + motion_feat = self.linears(face_image) + + # Motion synthesis via QR decomposition + weight = self.motion_synthesis_weight + 1e-8 + Q = torch.linalg.qr(weight.to(torch.float32))[0] + + input_diag = torch.diag_embed(motion_feat) # Alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1).to(motion_feat.dtype) + return out + + +class WanAnimateFaceEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, kernel_size: int = 3, eps: float = 1e-6): + super().__init__() + self.time_causal_padding = (kernel_size - 1, 0) + + self.conv1_local = nn.Conv1d(in_dim, 1024 * num_heads, kernel_size=kernel_size, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, eps, elementwise_affine=False) + self.act = nn.SiLU() + self.conv2 = nn.Conv1d(1024, 1024, kernel_size, stride=2) + self.conv3 = nn.Conv1d(1024, 1024, kernel_size, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(1024, eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(1024, eps, elementwise_affine=False) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = x.permute(0, 2, 1) + batch_size, channels, num_frames = x.shape + + x = F.pad(x, self.time_causal_padding, mode="replicate") + x = self.conv1_local(x) + x = x.unflatten(1, (-1, channels)).flatten(0, 1).permute(0, 2, 1) + + x = self.norm1(x) + x = self.act(x) + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode="replicate") + x = self.conv2(x) + x = x.permute(0, 2, 1) + x = self.norm2(x) + x = self.act(x) + x = x.permute(0, 2, 1) + x = F.pad(x, self.time_causal_padding, mode="replicate") + x = self.conv3(x) + x = x.permute(0, 2, 1) + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + + padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + +class WanTimeTextImageMotionFaceEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: int, + motion_encoder_dim: int, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + self.image_embedder = WanImageEmbedding(image_embed_dim, dim) + self.motion_embedder = WanAnimateMotionEmbedder(size=512, style_dim=512, motion_dim=20) + self.face_embedder = WanAnimateFaceEmbedder(in_dim=motion_encoder_dim, hidden_dim=dim, num_heads=4) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + face_pixel_values: Optional[torch.Tensor] = None, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = get_parameter_dtype(self.time_embedder) + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + # Motion vector computation from face pixel values + batch_size, channels, num_face_frames, height, width = face_pixel_values.shape + # Rearrange from (B, C, T, H, W) to (B*T, C, H, W) + face_pixel_values_flat = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) + + # Extract motion features using motion embedder + motion_vec = self.motion_embedder(face_pixel_values_flat) + motion_vec = motion_vec.view(batch_size, num_face_frames, -1) + + # Encode motion vectors through face embedder + motion_vec = self.face_embedder(motion_vec) + + # Add padding at the beginning (prepend zeros) + batch_size, T_motion, N_motion, C_motion = motion_vec.shape + pad_face = torch.zeros(batch_size, 1, N_motion, C_motion, dtype=motion_vec.dtype, device=motion_vec.device) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, motion_vec + + +class WanAnimateFaceBlock(nn.Module): + _attention_backend = None + _parallel_config = None + + def __init__( + self, + hidden_size: int, + heads_num: int, + eps: float = 1e-6, + ): + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2) + self.linear1_q = nn.Linear(hidden_size, hidden_size) + self.linear2 = nn.Linear(hidden_size, hidden_size) + + self.q_norm = nn.RMSNorm(head_dim, eps) + self.k_norm = nn.RMSNorm(head_dim, eps) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, eps, elementwise_affine=False) + self.pre_norm_motion = nn.LayerNorm(hidden_size, eps, elementwise_affine=False) + + def set_attention_backend(self, backend): + """Set the attention backend for this face block.""" + self._attention_backend = backend + + def set_parallel_config(self, config): + """Set the parallel configuration for this face block.""" + self._parallel_config = config + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + ) -> torch.Tensor: + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = kv.view(B, T, N, 2, self.heads_num, -1).permute(3, 0, 1, 2, 4, 5) + q = q.unflatten(2, (self.heads_num, -1)) + + q = self.q_norm(q.float()).type_as(q) + k = self.k_norm(k.float()).type_as(k) + + k = k.flatten(0, 1) + v = v.flatten(0, 1) + + q = q.unflatten(1, (T_comp, -1)).flatten(0, 1) + + attn = dispatch_attention_fn( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + attn = attn.unflatten(0, (B, T_comp)).flatten(1, 2) + + output = self.linear2(attn) + + return output + + +class WanAnimateTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + r""" + A Transformer model for video-like data used in the WanAnimate model. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `512`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + window_size (`Tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`bool`, defaults to `True`): + Enable query/key normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + image_dim (`int`, *optional*, defaults to `1280`): + The number of channels to use for the image embedding. If `None`, no projection is used. + added_kv_proj_dim (`int`, *optional*, defaults to `5120`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["WanAnimateTransformerBlock"] + _keep_in_fp32_modules = [ + "time_embedder", + "scale_shift_table", + "norm1", + "norm2", + "norm3", + "motion_synthesis_weight", + ] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 36, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = 1280, + added_kv_proj_dim: Optional[int] = 5120, + rope_max_seq_len: int = 1024, + motion_encoder_dim: int = 512, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.pose_patch_embedding = nn.Conv3d(16, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + self.condition_embedder = WanTimeTextImageMotionFaceEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + motion_encoder_dim=motion_encoder_dim, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + WanTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + self.face_adapter = nn.ModuleList( + [ + WanAnimateFaceBlock( + inner_dim, + num_attention_heads, + ) + for _ in range(num_layers // 5) + ] + ) + + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def set_attention_backend(self, backend: str): + """ + Set the attention backend for the transformer and all face adapter blocks. + + Args: + backend (`str`): The attention backend to use (e.g., 'flash', 'sdpa', 'xformers'). + """ + from ..attention_dispatch import AttentionBackendName + + # Validate backend + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend_enum = AttentionBackendName(backend.lower()) + + # Call parent ModelMixin method to set backend for all attention modules + super().set_attention_backend(backend) + + # Also set backend for all face adapter blocks (which use dispatch_attention_fn directly) + for face_block in self.face_adapter: + face_block.set_attention_backend(backend_enum) + + def set_parallel_config(self, config): + """ + Set the parallel configuration for all face adapter blocks. + + Args: + config: The parallel configuration to use. + """ + for face_block in self.face_adapter: + face_block.set_parallel_config(config) + + def forward( + self, + hidden_states: torch.Tensor, + pose_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + face_pixel_values: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1. Rotary position embedding + rotary_emb = self.rope(hidden_states) + + # 2. Patch embedding + hidden_states = self.patch_embedding(hidden_states) + pose_hidden_states = self.pose_patch_embedding(pose_hidden_states) + # Add pose embeddings to hidden states + hidden_states[:, :, 1:] = hidden_states[:, :, 1:] + pose_hidden_states[:, :, 1:] + hidden_states = hidden_states.flatten(2).transpose(1, 2) + # sequence_length = int(math.ceil(np.prod([post_patch_num_frames, post_patch_height, post_patch_width]) // 4)) + # hidden_states = torch.cat([hidden_states, hidden_states.new_zeros(hidden_states.shape[0], sequence_length - hidden_states.shape[1], hidden_states.shape[2])], dim=1) + pose_hidden_states = pose_hidden_states.flatten(2).transpose(1, 2) + + # 3. Condition embeddings (time, text, image, motion) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, motion_vec = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, face_pixel_values + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # 4. Image embedding + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 5. Transformer blocks with face adapter integration + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block_idx, block in enumerate(self.blocks): + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + + # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...) + if block_idx % 5 == 0: + face_adapter_output = self.face_adapter[block_idx // 5](hidden_states, motion_vec) + hidden_states = face_adapter_output + hidden_states + else: + for block_idx, block in enumerate(self.blocks): + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...) + if block_idx % 5 == 0: + face_adapter_output = self.face_adapter[block_idx // 5](hidden_states, motion_vec) + hidden_states = face_adapter_output + hidden_states + + # 6. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c438caed571f..be7fea5156ef 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -381,7 +381,13 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", ] - _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["wan"] = [ + "WanPipeline", + "WanImageToVideoPipeline", + "WanVideoToVideoPipeline", + "WanVACEPipeline", + "WanAnimatePipeline", + ] _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", @@ -788,7 +794,13 @@ UniDiffuserTextDecoder, ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline - from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline + from .wan import ( + WanAnimatePipeline, + WanImageToVideoPipeline, + WanPipeline, + WanVACEPipeline, + WanVideoToVideoPipeline, + ) from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py index bb96372b1db2..ad51a52f9242 100644 --- a/src/diffusers/pipelines/wan/__init__.py +++ b/src/diffusers/pipelines/wan/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_wan"] = ["WanPipeline"] + _import_structure["pipeline_wan_animate"] = ["WanAnimatePipeline"] _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"] _import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"] _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"] @@ -35,10 +36,10 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_wan import WanPipeline + from .pipeline_wan_animate import WanAnimatePipeline from .pipeline_wan_i2v import WanImageToVideoPipeline from .pipeline_wan_vace import WanVACEPipeline from .pipeline_wan_video2video import WanVideoToVideoPipeline - else: import sys diff --git a/src/diffusers/pipelines/wan/pipeline_wan_animate.py b/src/diffusers/pipelines/wan/pipeline_wan_animate.py new file mode 100644 index 000000000000..fef436a5f46d --- /dev/null +++ b/src/diffusers/pipelines/wan/pipeline_wan_animate.py @@ -0,0 +1,1027 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL +import regex as re +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanAnimateTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import numpy as np + >>> from diffusers import AutoencoderKLWan, WanAnimatePipeline + >>> from diffusers.utils import export_to_video, load_image, load_video + >>> from transformers import CLIPVisionModel + + >>> model_id = "Wan-AI/Wan2.2-Animate-14B-720P-Diffusers" + >>> image_encoder = CLIPVisionModel.from_pretrained( + ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + ... ) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = WanAnimatePipeline.from_pretrained( + ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> pose_video = load_video("path/to/pose_video.mp4") + >>> face_video = load_video("path/to/face_video.mp4") + >>> max_area = 480 * 832 + >>> aspect_ratio = image.height / image.width + >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + >>> image = image.resize((width, height)) + >>> prompt = ( + ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " + ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + ... ) + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... image=image, + ... pose_video=pose_video, + ... face_video=face_video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=height, + ... width=width, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanAnimatePipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + WanAnimatePipeline takes a character image, pose video, and face video as input, and generates a video in these two + modes: + + 1. Animation mode: The model generates a video of the character image that mimics the human motion in the input + pose and face videos. + 2. Replacement mode: The model replaces the character image with the input video, using background and mask videos. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`WanAnimateTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + transformer: WanAnimateTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor_for_mask = VideoProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False + ) + self.image_processor = image_processor + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + mode=None, + num_frames_for_temporal_guidance=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if pose_video is None: + raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.") + if face_video is None: + raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.") + if not isinstance(pose_video, list) or not isinstance(face_video, list): + raise ValueError("`pose_video` and `face_video` must be lists of PIL images.") + if len(pose_video) == 0 or len(face_video) == 0: + raise ValueError("`pose_video` and `face_video` must contain at least one frame.") + if mode == "replacement" and (background_video is None or mask_video is None): + raise ValueError( + "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video` undefined when mode is `replacement`." + ) + if mode == "replacement" and (not isinstance(background_video, list) or not isinstance(mask_video, list)): + raise ValueError( + "`background_video` and `mask_video` must be lists of PIL images when mode is `replacement`." + ) + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if mode is not None and (not isinstance(mode, str) or mode not in ("animation", "replacement")): + raise ValueError( + f"`mode` has to be of type `str` and in ('animation', 'replacement') but its type is {type(mode)} and value is {mode}" + ) + + if num_frames_for_temporal_guidance is not None and ( + not isinstance(num_frames_for_temporal_guidance, int) or num_frames_for_temporal_guidance not in (1, 5) + ): + raise ValueError( + f"`num_frames_for_temporal_guidance` has to be of type `int` and 1 or 5 but its type is {type(num_frames_for_temporal_guidance)} and value is {num_frames_for_temporal_guidance}" + ) + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 80, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + conditioning_pixel_values: Optional[torch.Tensor] = None, + refer_t_pixel_values: Optional[torch.Tensor] = None, + background_pixel_values: Optional[torch.Tensor] = None, + mask_pixel_values: Optional[torch.Tensor] = None, + mask_reft_len: Optional[int] = None, + mode: Optional[str] = None, + y_ref: Optional[str] = None, + calculate_noise_latents_only: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_latent_frames = num_frames // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # Prepare latent normalization parameters + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + # The first outer loop + if mask_reft_len == 0: + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + image = image.to(device=device, dtype=self.vae.dtype) + # Encode conditioning (pose) video + conditioning_pixel_values = conditioning_pixel_values.to(device=device, dtype=self.vae.dtype) + + if isinstance(generator, list): + ref_latents = [retrieve_latents(self.vae.encode(image), sample_mode="argmax") for _ in generator] + ref_latents = torch.cat(ref_latents) + pose_latents = [ + retrieve_latents(self.vae.encode(conditioning_pixel_values), sample_mode="argmax") + for _ in generator + ] + pose_latents = torch.cat(pose_latents) + else: + ref_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + ref_latents = ref_latents.repeat(batch_size, 1, 1, 1, 1) + pose_latents = retrieve_latents(self.vae.encode(conditioning_pixel_values), sample_mode="argmax") + pose_latents = pose_latents.repeat(batch_size, 1, 1, 1, 1) + + ref_latents = (ref_latents.to(dtype) - latents_mean) * latents_std + pose_latents = (pose_latents.to(dtype) - latents_mean) * latents_std + + mask_ref = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, device) + y_ref = torch.concat([mask_ref, ref_latents], dim=1) + + refer_t_pixel_values = refer_t_pixel_values.to(self.vae.dtype) + background_pixel_values = background_pixel_values.to(self.vae.dtype) + + if mode == "replacement" and mask_pixel_values is not None: + mask_pixel_values = 1 - mask_pixel_values + mask_pixel_values = mask_pixel_values.flatten(0, 1) + mask_pixel_values = F.interpolate(mask_pixel_values, size=(latent_height, latent_width), mode="nearest") + mask_pixel_values = mask_pixel_values.unflatten(0, (-1, 1)) + + if mask_reft_len > 0 and not calculate_noise_latents_only: + if mode == "replacement": + y_reft = retrieve_latents( + self.vae.encode( + torch.concat( + [ + refer_t_pixel_values[:, :, :mask_reft_len], + background_pixel_values[:, :, mask_reft_len:], + ], + dim=2, + ) + ), + sample_mode="argmax", + ) + else: + y_reft = retrieve_latents( + self.vae.encode( + torch.concat( + [ + F.interpolate( + refer_t_pixel_values[:, :, :mask_reft_len], size=(height, width), mode="bicubic" + ), + torch.zeros( + batch_size, + 3, + num_frames - mask_reft_len, + height, + width, + device=device, + dtype=self.vae.dtype, + ), + ], + dim=2, + ) + ), + sample_mode="argmax", + ) + elif mask_reft_len == 0 and not calculate_noise_latents_only: + if mode == "replacement": + y_reft = retrieve_latents(self.vae.encode(background_pixel_values), sample_mode="argmax") + else: + y_reft = retrieve_latents( + self.vae.encode( + torch.zeros( + batch_size, + 3, + num_frames - mask_reft_len, + height, + width, + device=device, + dtype=self.vae.dtype, + ) + ), + sample_mode="argmax", + ) + + if mask_reft_len == 0 or not calculate_noise_latents_only: + y_reft = (y_reft.to(dtype) - latents_mean) * latents_std + msk_reft = self.get_i2v_mask( + batch_size, num_latent_frames, latent_height, latent_width, mask_reft_len, mask_pixel_values, device + ) + + y_reft = torch.concat([msk_reft, y_reft], dim=1) + condition = torch.concat([y_ref, y_reft], dim=2) + + if mask_reft_len == 0 and not calculate_noise_latents_only: + return latents, condition, pose_latents, y_ref, mask_pixel_values + elif mask_reft_len > 0 and not calculate_noise_latents_only: + return latents, condition + elif mask_reft_len > 0 and calculate_noise_latents_only: + return latents + + def get_i2v_mask( + self, batch_size, latent_t, latent_h, latent_w, mask_len=1, mask_pixel_values=None, device="cuda" + ): + if mask_pixel_values is None: + mask_lat_size = torch.zeros(batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, device=device) + else: + mask_lat_size = mask_pixel_values.clone() + mask_lat_size[:, :, :mask_len] = 1 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, self.vae_scale_factor_temporal, latent_h, latent_w + ).transpose(1, 2) + + return mask_lat_size + + def pad_video(self, frames, num_target_frames): + """ + pad_video([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2] + """ + idx = 0 + flip = False + target_frames = [] + while len(target_frames) < num_target_frames: + target_frames.append(deepcopy(frames[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == len(frames) - 1: + flip = not flip + + return target_frames + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + pose_video: List[PIL.Image.Image], + face_video: List[PIL.Image.Image], + background_video: Optional[List[PIL.Image.Image]] = None, + mask_video: Optional[List[PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 80, + num_inference_steps: int = 50, + mode: str = "animation", + num_frames_for_temporal_guidance: int = 1, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input character image to condition the generation on. Must be an image, a list of images or a + `torch.Tensor`. + pose_video (`List[PIL.Image.Image]`): + The input pose video to condition the generation on. Must be a list of PIL images. + face_video (`List[PIL.Image.Image]`): + The input face video to condition the generation on. Must be a list of PIL images. + background_video (`List[PIL.Image.Image]`, *optional*): + When mode is `"replacement"`, the input background video to condition the generation on. Must be a list + of PIL images. + mask_video (`List[PIL.Image.Image]`, *optional*): + When mode is `"replacement"`, the input mask video to condition the generation on. Must be a list of + PIL images. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + mode (`str`, defaults to `"animation"`): + The mode of the generation. Choose between `"animation"` and `"replacement"`. + num_frames_for_temporal_guidance (`int`, defaults to `1`): + The number of frames used for temporal guidance. Recommended to be 1 or 5. + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `80`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + mode, + num_frames_for_temporal_guidance, + ) + + if num_frames % self.vae_scale_factor_temporal != 0: + logger.warning( + f"`num_frames` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # Encode image embedding + if image_embeds is None: + image_embeds = self.encode_image(image, device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # Calculate the number of valid frames + num_real_frames = len(pose_video) + real_clip_len = num_frames - num_frames_for_temporal_guidance + last_clip_num = (num_real_frames - num_frames_for_temporal_guidance) % real_clip_len + if last_clip_num == 0: + extra = 0 + else: + extra = real_clip_len - last_clip_num + num_target_frames = num_real_frames + extra + + pose_video = self.pad_video(pose_video, num_target_frames) + face_video = self.pad_video(face_video, num_target_frames) + if mode == "replacement": + background_video = self.pad_video(background_video, num_target_frames) + mask_video = self.pad_video(mask_video, num_target_frames) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + # Get dimensions from the first frame of pose_video (PIL Image.size returns (width, height)) + width, height = pose_video[0].size + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + face_video = self.video_processor.preprocess_video(face_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + if mode == "replacement": + background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + start = 0 + end = num_frames + all_out_frames = [] + out_frames = None + y_ref = None + calculate_noise_latents_only = False + + while True: + if start + num_frames_for_temporal_guidance >= len(pose_video): + break + + if start == 0: + mask_reft_len = 0 + else: + mask_reft_len = num_frames_for_temporal_guidance + + conditioning_pixel_values = pose_video[start:end] + face_pixel_values = face_video[start:end] + + if start == 0: + refer_t_pixel_values = torch.zeros(image.shape[0], 3, num_frames_for_temporal_guidance, height, width) + elif start > 0: + refer_t_pixel_values = ( + out_frames[0, :, -num_frames_for_temporal_guidance:].clone().detach().unsqueeze(0) + ) + + if mode == "replacement": + background_pixel_values = background_video[start:end] + mask_pixel_values = mask_video[start:end].permute(0, 2, 1, 3, 4) + else: + mask_pixel_values = None + background_pixel_values = None + + latents_outputs = self.prepare_latents( + image if start == 0 else None, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents if start == 0 else None, + conditioning_pixel_values, + refer_t_pixel_values, + background_pixel_values, + mask_pixel_values if not calculate_noise_latents_only else None, + mask_reft_len, + mode, + y_ref if start > 0 and not calculate_noise_latents_only else None, + calculate_noise_latents_only, + ) + # First iteration + if start == 0: + latents, condition, pose_latents, y_ref, mask_pixel_values = latents_outputs + # Second iteration + elif start > 0 and not calculate_noise_latents_only: + latents, condition = latents_outputs + calculate_noise_latents_only = True + # Subsequent iterations + else: + latents = latents_outputs + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + pose_hidden_states=pose_latents, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + face_pixel_values=face_pixel_values, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + # Blank out face for unconditional guidance (set all pixels to -1) + face_pixel_values_uncond = face_pixel_values * 0 - 1 + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + pose_hidden_states=pose_latents, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + face_pixel_values=face_pixel_values_uncond, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + x0 = latents + + x0 = x0.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(x0.device, x0.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + x0.device, x0.dtype + ) + x0 = x0 / latents_std + latents_mean + # Skip the first latent frame (used for conditioning) + out_frames = self.vae.decode(x0[:, :, 1:], return_dict=False)[0] + + if start > 0: + out_frames = out_frames[:, :, num_frames_for_temporal_guidance:] + all_out_frames.append(out_frames) + + start += num_frames - num_frames_for_temporal_guidance + end += num_frames - num_frames_for_temporal_guidance + + self._current_timestep = None + + if not output_type == "latent": + video = torch.cat(all_out_frames, dim=2)[:, :, :num_real_frames] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + # TODO + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 5d62709c28fd..e9bf9796b5b3 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1488,6 +1488,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class WanAnimateTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class WanTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3244ef12ef87..3bbf8da89d98 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -3407,6 +3407,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class WanAnimatePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class WanImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/wan/test_wan_animate.py b/tests/pipelines/wan/test_wan_animate.py new file mode 100644 index 000000000000..aec3c0bff222 --- /dev/null +++ b/tests/pipelines/wan/test_wan_animate.py @@ -0,0 +1,245 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + AutoTokenizer, + CLIPImageProcessor, + CLIPVisionConfig, + CLIPVisionModelWithProjection, + T5EncoderModel, +) + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanAnimatePipeline, + WanAnimateTransformer3DModel, +) + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class WanAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanAnimatePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanAnimateTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + image_dim=4, + pos_embed_seq_len=2 * (4 * 4 + 1), + ) + + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=4, + projection_dim=4, + num_hidden_layers=2, + num_attention_heads=2, + image_size=4, + intermediate_size=16, + patch_size=1, + ) + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + torch.manual_seed(0) + image_processor = CLIPImageProcessor(crop_size=4, size=4) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "image_encoder": image_encoder, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + num_frames = 17 + height = 16 + width = 16 + + pose_video = [Image.new("RGB", (height, width))] * num_frames + face_video = [Image.new("RGB", (height, width))] * num_frames + image = Image.new("RGB", (height, width)) + + inputs = { + "image": image, + "pose_video": pose_video, + "face_video": face_video, + "prompt": "dance monkey", + "negative_prompt": "negative", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": height, + "width": width, + "num_frames": num_frames, + "mode": "animation", + "num_frames_for_temporal_guidance": 1, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + # fmt: off + expected_slice = [0.4523, 0.45198, 0.44872, 0.45326, 0.45211, 0.45258, 0.45344, 0.453, 0.52431, 0.52572, 0.50701, 0.5118, 0.53717, 0.53093, 0.50557, 0.51402] + # fmt: on + + video_slice = video.flatten() + video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) + video_slice = [round(x, 5) for x in video_slice.tolist()] + self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + + def test_inference_with_single_reference_image(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["reference_images"] = Image.new("RGB", (16, 16)) + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + # fmt: off + expected_slice = [0.45247, 0.45214, 0.44874, 0.45314, 0.45171, 0.45299, 0.45428, 0.45317, 0.51378, 0.52658, 0.53361, 0.52303, 0.46204, 0.50435, 0.52555, 0.51342] + # fmt: on + + video_slice = video.flatten() + video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) + video_slice = [round(x, 5) for x in video_slice.tolist()] + self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + + def test_inference_with_multiple_reference_image(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2] + video = pipe(**inputs).frames[0] + self.assertEqual(video.shape, (17, 3, 16, 16)) + + # fmt: off + expected_slice = [0.45321, 0.45221, 0.44818, 0.45375, 0.45268, 0.4519, 0.45271, 0.45253, 0.51244, 0.52223, 0.51253, 0.51321, 0.50743, 0.51177, 0.51626, 0.50983] + # fmt: on + + video_slice = video.flatten() + video_slice = torch.cat([video_slice[:8], video_slice[-8:]]) + video_slice = [round(x, 5) for x in video_slice.tolist()] + self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("Batching is not yet supported with this pipeline") + def test_inference_batch_single_identical(self): + return super().test_inference_batch_single_identical() + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" + ) + def test_float16_inference(self): + pass + + @unittest.skip( + "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs" + ) + def test_save_load_float16(self): + pass diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 0f4fd408a7c1..b42764be10d6 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -16,6 +16,7 @@ HiDreamImageTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline, + WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, ) @@ -721,6 +722,33 @@ def get_dummy_inputs(self): } +class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf" + torch_dtype = torch.bfloat16 + model_cls = WanAnimateTransformer3DModel + expected_memory_use_in_gb = 9 + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states": torch.randn( + (1, 96, 2, 64, 64), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "control_hidden_states_scale": torch.randn( + (8,), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + @require_torch_version_greater("2.7.1") class GGUFCompileTests(QuantCompileTests, unittest.TestCase): torch_dtype = torch.bfloat16