diff --git a/examples/flf2v_input_first_frame.png b/examples/flf2v_input_first_frame.png
new file mode 100644
index 0000000..032cd5c
Binary files /dev/null and b/examples/flf2v_input_first_frame.png differ
diff --git a/examples/flf2v_input_last_frame.png b/examples/flf2v_input_last_frame.png
new file mode 100644
index 0000000..83ac8c5
Binary files /dev/null and b/examples/flf2v_input_last_frame.png differ
diff --git a/generate.py b/generate.py
index b348c3e..e85ef91 100644
--- a/generate.py
+++ b/generate.py
@@ -43,6 +43,14 @@
"image":
"examples/i2v_input.JPG",
},
+ "flf2v-14B": {
+ "prompt":
+ "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
+ "first_frame":
+ "examples/flf2v_input_first_frame.png",
+ "last_frame":
+ "examples/flf2v_input_last_frame.png",
+ },
}
@@ -60,6 +68,8 @@ def _validate_args(args):
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
+ if "flf2v" in args.task:
+ args.sample_shift = 16
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
@@ -187,7 +197,17 @@ def _parse_args():
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the image or video.")
+ parser.add_argument(
+ "--first_frame",
+ type=str,
+ default=None,
+ help="[first-last frame to video] The image (first frame) to generate the video from.")
+ parser.add_argument(
+ "--last_frame",
+ type=str,
+ default=None,
+ help="[first-last frame to video] The image (last frame) to generate the video from.")
parser.add_argument(
"--image",
type=str,
@@ -305,7 +325,7 @@ def generate(args):
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
- model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
+ model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
@@ -440,7 +460,7 @@ def generate(args):
end = time.time()
logging.info(f"Generating video used time {end - begin: .4f}s")
- else:
+ elif "i2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
@@ -561,7 +581,112 @@ def generate(args):
stream.synchronize()
end = time.time()
logging.info(f"Generating video used time {end - begin: .4f}s")
+ else:
+ if args.prompt is None:
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
+ if args.first_frame is None or args.last_frame is None:
+ args.first_frame = EXAMPLE_PROMPT[args.task]["first_frame"]
+ args.last_frame = EXAMPLE_PROMPT[args.task]["last_frame"]
+ logging.info(f"Input prompt: {args.prompt}")
+ logging.info(f"Input first frame: {args.first_frame}")
+ logging.info(f"Input last frame: {args.last_frame}")
+ first_frame = Image.open(args.first_frame).convert("RGB")
+ last_frame = Image.open(args.last_frame).convert("RGB")
+ if args.use_prompt_extend:
+ logging.info("Extending prompt ...")
+ if rank == 0:
+ prompt_output = prompt_expander(
+ args.prompt,
+ tar_lang=args.prompt_extend_target_lang,
+ image=[first_frame, last_frame],
+ seed=args.base_seed)
+ if prompt_output.status == False:
+ logging.info(
+ f"Extending prompt failed: {prompt_output.message}")
+ logging.info("Falling back to original prompt.")
+ input_prompt = args.prompt
+ else:
+ input_prompt = prompt_output.prompt
+ input_prompt = [input_prompt]
+ else:
+ input_prompt = [None]
+ if dist.is_initialized():
+ dist.broadcast_object_list(input_prompt, src=0)
+ args.prompt = input_prompt[0]
+ logging.info(f"Extended prompt: {args.prompt}")
+
+ logging.info("Creating WanFLF2V pipeline.")
+ wan_flf2v = wan.WanFLF2V(
+ config=cfg,
+ checkpoint_dir=args.ckpt_dir,
+ device_id=device,
+ rank=rank,
+ t5_fsdp=args.t5_fsdp,
+ dit_fsdp=args.dit_fsdp,
+ use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+ t5_cpu=args.t5_cpu,
+ use_vae_parallel=args.vae_parallel,
+ )
+
+ transformer = wan_flf2v.model
+ if args.use_attentioncache:
+ config = CacheConfig(
+ method="attention_cache",
+ blocks_count=len(transformer.blocks),
+ steps_count=args.sample_steps,
+ step_start=args.start_step,
+ step_interval=args.attentioncache_interval,
+ step_end=args.end_step
+ )
+ else:
+ config = CacheConfig(
+ method="attention_cache",
+ blocks_count=len(transformer.blocks),
+ steps_count=args.sample_steps
+ )
+ cache = CacheAgent(config)
+ if args.dit_fsdp:
+ for block in transformer._fsdp_wrapped_module.blocks:
+ block._fsdp_wrapped_module.cache = cache
+ block._fsdp_wrapped_module.args = args
+ else:
+ for block in transformer.blocks:
+ block.cache = cache
+ block.args = args
+
+ logging.info(f"Warm up 2 steps...")
+ video = wan_flf2v.generate(
+ args.prompt,
+ first_frame,
+ last_frame,
+ max_area=MAX_AREA_CONFIGS[args.size],
+ frame_num=args.frame_num,
+ shift=args.sample_shift,
+ sample_solver=args.sample_solver,
+ sampling_steps=2,
+ guide_scale=args.sample_guide_scale,
+ seed=args.base_seed,
+ offload_model=args.offload_model)
+
+ logging.info("Generating video ...")
+ stream.synchronize()
+ begin = time.time()
+ video = wan_flf2v.generate(
+ args.prompt,
+ first_frame,
+ last_frame,
+ max_area=MAX_AREA_CONFIGS[args.size],
+ frame_num=args.frame_num,
+ shift=args.sample_shift,
+ sample_solver=args.sample_solver,
+ sampling_steps=args.sample_steps,
+ guide_scale=args.sample_guide_scale,
+ seed=args.base_seed,
+ offload_model=args.offload_model)
+ stream.synchronize()
+ end = time.time()
+ logging.info(f"Generating video used time {end - begin: .4f}s")
if rank == 0:
if args.save_file is None:
diff --git a/gradio/fl2v_14B_singleGPU.py b/gradio/fl2v_14B_singleGPU.py
new file mode 100644
index 0000000..c55ed0c
--- /dev/null
+++ b/gradio/fl2v_14B_singleGPU.py
@@ -0,0 +1,254 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import argparse
+import gc
+import os
+import os.path as osp
+import sys
+import warnings
+
+import gradio as gr
+
+warnings.filterwarnings('ignore')
+
+# Model
+sys.path.insert(
+ 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
+import wan
+from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
+from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
+from wan.utils.utils import cache_video
+
+# Global Var
+prompt_expander = None
+wan_flf2v_720P = None
+
+
+# Button Func
+def load_model(value):
+    """Load the FLF2V pipeline for the selected resolution.
+
+    Wired to the resolution dropdown; the returned string becomes the
+    dropdown's new value.
+
+    Args:
+        value: Selected dropdown entry ('------' or '720P').
+
+    Returns:
+        '720P' once the pipeline is available, '------' when nothing was
+        loaded, or `value` unchanged for any other entry.
+    """
+    global wan_flf2v_720P
+
+    if value == '------':
+        print("No model loaded")
+        return '------'
+
+    # Only the 720P entry maps to a model; echo back anything else.
+    if value != '720P':
+        return value
+
+    if args.ckpt_dir_720p is None:
+        print("Please specify the checkpoint directory for 720P model")
+        return '------'
+
+    # The pipeline is cached in a module global so it is built only once.
+    if wan_flf2v_720P is None:
+        gc.collect()
+
+        print("load 14B-720P flf2v model...", end='', flush=True)
+        wan_flf2v_720P = wan.WanFLF2V(
+            config=WAN_CONFIGS['flf2v-14B'],
+            checkpoint_dir=args.ckpt_dir_720p,
+            device_id=0,
+            rank=0,
+            t5_fsdp=False,
+            dit_fsdp=False,
+            use_usp=False,
+        )
+        print("done", flush=True)
+    return '720P'
+
+
+def prompt_enc(prompt, img_first, img_last, tar_lang):
+    """Extend `prompt` with the global prompt expander, conditioned on both frames.
+
+    Args:
+        prompt: Text typed by the user.
+        img_first: First-frame image from the UI, or None if not uploaded.
+        img_last: Last-frame image from the UI, or None if not uploaded.
+        tar_lang: Target language label from the UI; lower-cased before use.
+
+    Returns:
+        The extended prompt, or the original `prompt` when a frame is
+        missing, the expander is not initialised, or the expansion fails.
+    """
+    print('prompt extend...')
+    if img_first is None or img_last is None:
+        print('Please upload the first and last frames')
+        return prompt
+    # The expander is only created in the `__main__` block; calling None
+    # would raise TypeError, so fall back to the unmodified prompt.
+    if prompt_expander is None:
+        print('Prompt expander is not initialized')
+        return prompt
+    prompt_output = prompt_expander(
+        prompt, image=[img_first, img_last], tar_lang=tar_lang.lower())
+    if not prompt_output.status:
+        return prompt
+    return prompt_output.prompt
+
+
+def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
+                     resolution, sd_steps, guide_scale, shift_scale, seed,
+                     n_prompt):
+    """Generate a video from the first/last frames and write it to `example.mp4`.
+
+    Args:
+        flf2vid_prompt: Text prompt.
+        flf2vid_image_first: First-frame image from the UI, or None.
+        flf2vid_image_last: Last-frame image from the UI, or None.
+        resolution: Dropdown value; only '720P' triggers generation.
+        sd_steps: Passed to `generate` as `sampling_steps`.
+        guide_scale: Passed to `generate` as `guide_scale`.
+        shift_scale: Passed to `generate` as `shift`.
+        seed: Passed to `generate` as `seed`.
+        n_prompt: Negative prompt.
+
+    Returns:
+        "example.mp4" on success, or None when the request cannot be served
+        (no/unsupported resolution, missing frame, model not loaded).
+    """
+    if resolution == '------':
+        print(
+            'Please specify the resolution ckpt dir or specify the resolution')
+        return None
+
+    if resolution != '720P':
+        print('Sorry, currently only 720P is supported.')
+        return None
+
+    # Reject incomplete requests here rather than failing inside the pipeline.
+    if flf2vid_image_first is None or flf2vid_image_last is None:
+        print('Please upload the first and last frames')
+        return None
+
+    # The pipeline global stays None until `load_model` has finished.
+    if wan_flf2v_720P is None:
+        print('The 720P model is not loaded yet.')
+        return None
+
+    video = wan_flf2v_720P.generate(
+        flf2vid_prompt,
+        flf2vid_image_first,
+        flf2vid_image_last,
+        max_area=MAX_AREA_CONFIGS['720*1280'],
+        shift=shift_scale,
+        sampling_steps=sd_steps,
+        guide_scale=guide_scale,
+        n_prompt=n_prompt,
+        seed=seed,
+        offload_model=True)
+
+    cache_video(
+        tensor=video[None],
+        save_file="example.mp4",
+        fps=16,
+        nrow=1,
+        normalize=True,
+        value_range=(-1, 1))
+
+    return "example.mp4"
+
+
+# Interface
+def gradio_interface():
+    """Build the Gradio Blocks UI for first-last-frame-to-video generation.
+
+    Wires the resolution dropdown to `load_model`, the "Prompt Enhance"
+    button to `prompt_enc` and the "Generate Video" button to
+    `flf2v_generation`.
+
+    Returns:
+        The assembled `gr.Blocks` app; launching it is left to the caller.
+    """
+    with gr.Blocks() as demo:
+        gr.Markdown("""
+
+ Wan2.1 (FLF2V-14B)
+
+
+ Wan: Open and Advanced Large-Scale Video Generative Models.
+
+ """)
+
+        with gr.Row():
+            with gr.Column():
+                resolution = gr.Dropdown(
+                    label='Resolution',
+                    choices=['------', '720P'],
+                    value='------')
+                # Each upload widget gets its own DOM id; the two widgets
+                # previously shared "image_upload".
+                flf2vid_image_first = gr.Image(
+                    type="pil",
+                    label="Upload First Frame",
+                    elem_id="first_frame_upload",
+                )
+                flf2vid_image_last = gr.Image(
+                    type="pil",
+                    label="Upload Last Frame",
+                    elem_id="last_frame_upload",
+                )
+                flf2vid_prompt = gr.Textbox(
+                    label="Prompt",
+                    placeholder="Describe the video you want to generate",
+                )
+                tar_lang = gr.Radio(
+                    choices=["ZH", "EN"],
+                    label="Target language of prompt enhance",
+                    value="ZH")
+                run_p_button = gr.Button(value="Prompt Enhance")
+
+                with gr.Accordion("Advanced Options", open=True):
+                    with gr.Row():
+                        sd_steps = gr.Slider(
+                            label="Diffusion steps",
+                            minimum=1,
+                            maximum=1000,
+                            value=50,
+                            step=1)
+                        guide_scale = gr.Slider(
+                            label="Guide scale",
+                            minimum=0,
+                            maximum=20,
+                            value=5.0,
+                            step=1)
+                    with gr.Row():
+                        # NOTE(review): WanFLF2V.generate defaults to
+                        # shift=16 while this slider starts at 5.0 — confirm
+                        # which default the demo should use.
+                        shift_scale = gr.Slider(
+                            label="Shift scale",
+                            minimum=0,
+                            maximum=20,
+                            value=5.0,
+                            step=1)
+                        seed = gr.Slider(
+                            label="Seed",
+                            minimum=-1,
+                            maximum=2147483647,
+                            step=1,
+                            value=-1)
+                    n_prompt = gr.Textbox(
+                        label="Negative Prompt",
+                        placeholder="Describe the negative prompt you want to add"
+                    )
+
+                run_flf2v_button = gr.Button("Generate Video")
+
+            with gr.Column():
+                result_gallery = gr.Video(
+                    label='Generated Video', interactive=False, height=600)
+
+        # Selecting a resolution loads the model; the handler's return value
+        # is written back to the dropdown.
+        resolution.input(
+            fn=load_model, inputs=[resolution], outputs=[resolution])
+
+        run_p_button.click(
+            fn=prompt_enc,
+            inputs=[
+                flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
+                tar_lang
+            ],
+            outputs=[flf2vid_prompt])
+
+        run_flf2v_button.click(
+            fn=flf2v_generation,
+            inputs=[
+                flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
+                resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
+            ],
+            outputs=[result_gallery],
+        )
+
+    return demo
+
+
+# Main
+def _parse_args():
+    """Parse command-line options for the FLF2V Gradio demo.
+
+    Returns:
+        argparse.Namespace with `ckpt_dir_720p`, `prompt_extend_method` and
+        `prompt_extend_model`.
+
+    Exits through `parser.error` (usage message, exit status 2) when
+    `--ckpt_dir_720p` is not given.
+    """
+    parser = argparse.ArgumentParser(
+        description="Generate a video from a text prompt or image using Gradio")
+    parser.add_argument(
+        "--ckpt_dir_720p",
+        type=str,
+        default=None,
+        help="The path to the checkpoint directory.")
+    parser.add_argument(
+        "--prompt_extend_method",
+        type=str,
+        default="local_qwen",
+        choices=["dashscope", "local_qwen"],
+        help="The prompt extend method to use.")
+    parser.add_argument(
+        "--prompt_extend_model",
+        type=str,
+        default=None,
+        help="The prompt extend model to use.")
+
+    args = parser.parse_args()
+    # Explicit check instead of `assert`, which is stripped under `python -O`.
+    if args.ckpt_dir_720p is None:
+        parser.error("Please specify the checkpoint directory.")
+
+    return args
+
+
+if __name__ == '__main__':
+    # Script entry point: parse options, build the prompt expander, serve the UI.
+    args = _parse_args()
+
+    print("Step1: Init prompt_expander...", end='', flush=True)
+    extend_method = args.prompt_extend_method
+    if extend_method not in ("dashscope", "local_qwen"):
+        raise NotImplementedError(
+            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
+    # Both expanders are created with is_vl=True.
+    if extend_method == "dashscope":
+        prompt_expander = DashScopePromptExpander(
+            model_name=args.prompt_extend_model, is_vl=True)
+    else:
+        prompt_expander = QwenPromptExpander(
+            model_name=args.prompt_extend_model, is_vl=True, device=0)
+    print("done", flush=True)
+
+    demo = gradio_interface()
+    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
diff --git a/tests/flf2v_test.sh b/tests/flf2v_test.sh
new file mode 100644
index 0000000..0cddccc
--- /dev/null
+++ b/tests/flf2v_test.sh
@@ -0,0 +1,24 @@
+GPUS=8
+PY_FILE="../generate.py"
+FLF2V_14B_14B_CKPT_DIR=Wan-AI/Wan2.1-FLF2V-14B-720P
+# export CPLUS_INCLUDE_PATH=/usr/include/c++/12/:/usr/include/c++/12/aarch64-openEuler-linux/:$CPLUS_INCLUDE_PATH
+
+# Runs the flf2v-14B task with torchrun on $GPUS processes using the bundled example frames.
+export ALGO=1
+export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True'
+export TASK_QUEUE_ENABLE=2
+export CPU_AFFINITY_CONF=1
+export TOKENIZERS_PARALLELISM=false
+# Expansions and the size are quoted so the shell cannot word-split them or glob-expand "960*960".
+torchrun \
+    --master_port=23152 \
+    --nproc_per_node="$GPUS" "$PY_FILE" \
+    --task flf2v-14B \
+    --ckpt_dir "$FLF2V_14B_14B_CKPT_DIR" \
+    --size "960*960" \
+    --dit_fsdp \
+    --t5_fsdp \
+    --ulysses_size "$GPUS" \
+    --first_frame ../examples/flf2v_input_first_frame.png \
+    --last_frame ../examples/flf2v_input_last_frame.png \
+    --prompt "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。"
\ No newline at end of file
diff --git a/tests/test.sh b/tests/test.sh
index bf40cd7..714f153 100644
--- a/tests/test.sh
+++ b/tests/test.sh
@@ -105,9 +105,22 @@ function i2v_14B_720p() {
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}
+function flf2v_14B_720p() {
+    # Smoke-test the flf2v-14B task against the FLF2V checkpoint on one GPU and on $GPUS GPUs.
+    FLF2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-FLF2V-14B-720P"
+    # 1-GPU Test
+    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> flf2v_14B 1-GPU Test: "
+    python $PY_FILE --task flf2v-14B --size 720*1280 --ckpt_dir $FLF2V_14B_CKPT_DIR
+
+    # Multiple GPU Test
+    echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> flf2v_14B Multiple GPU Test: "
+    torchrun --nproc_per_node=$GPUS $PY_FILE --task flf2v-14B --ckpt_dir $FLF2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
+}
+
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
+flf2v_14B_720p
\ No newline at end of file
diff --git a/wan/__init__.py b/wan/__init__.py
index df36ebe..9f1fec2 100644
--- a/wan/__init__.py
+++ b/wan/__init__.py
@@ -1,3 +1,4 @@
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
+from .first_last_frame2video import WanFLF2V
\ No newline at end of file
diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py
index 0b3f359..7773fc5 100644
--- a/wan/configs/__init__.py
+++ b/wan/configs/__init__.py
@@ -12,11 +12,17 @@
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'
+# the config of flf2v_14B is the same as i2v_14B
+flf2v_14B = copy.deepcopy(i2v_14B)
+flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
+flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt
+
WAN_CONFIGS = {
't2v-14B': t2v_14B,
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
+ 'flf2v-14B': flf2v_14B
}
SIZE_CONFIGS = {
@@ -27,6 +33,7 @@
'480*720': (480, 720),
'720*480': (720, 480),
'1024*1024': (1024, 1024),
+ '960*960': (960, 960),
}
MAX_AREA_CONFIGS = {
@@ -36,11 +43,14 @@
'832*480': 832 * 480,
'480*720': 480 * 720,
'720*480': 720 * 480,
+ '1024*1024': 1024 * 1024,
+ '960*960': 960 * 960,
}
SUPPORTED_SIZES = {
- 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*720', '720*480'),
+ 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*720', '720*480', '1024*1024'),
't2v-1.3B': ('480*832', '832*480', '480*720', '720*480'),
- 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*720', '720*480'),
+ 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*720', '720*480', '1024*1024'),
+ 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '960*960'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
}
diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py
index 12e8e20..53bf221 100644
--- a/wan/configs/wan_i2v_14B.py
+++ b/wan/configs/wan_i2v_14B.py
@@ -8,6 +8,7 @@
i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
+i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index 1c56b2b..7a4aedc 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -27,21 +27,14 @@ def pad_freqs(original_tensor, target_len):
return padded_tensor
-@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs_list):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
- s, n, c = x.size(1), x.size(2), x.size(3)
- output = []
- for i, (f, h, w) in enumerate(grid_sizes.tolist()):
- x_i = x[i, :s].reshape(1, s, n, c)
- cos, sin = freqs_list[i]
- x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True)
- output.append(x_i)
- return torch.cat(output).float()
+ cos, sin = freqs_list[0]
+ return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True)
def usp_dit_forward(
diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py
new file mode 100644
index 0000000..fc7d430
--- /dev/null
+++ b/wan/first_last_frame2video.py
@@ -0,0 +1,425 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import gc
+import logging
+import math
+import os
+import random
+import sys
+import types
+from contextlib import contextmanager
+from functools import partial
+
+import numpy as np
+import torch
+import torch.cuda.amp as amp
+import torch.distributed as dist
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+
+from .distributed.fsdp import shard_model
+from .modules.clip import CLIPModel
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae import WanVAE
+from .utils.fm_solvers import (
+ FlowDPMSolverMultistepScheduler,
+ get_sampling_sigmas,
+ retrieve_timesteps,
+)
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .vae_patch_parallel import VAE_patch_parallel, set_vae_patch_parallel
+from wan.distributed.parallel_mgr import (
+ get_sequence_parallel_world_size,
+ get_classifier_free_guidance_world_size,
+ get_classifier_free_guidance_rank,
+ get_cfg_group,
+)
+
+
+class WanFLF2V:
+ # First-last-frame-to-video pipeline: bundles the T5 text encoder, CLIP
+ # image encoder, VAE and WanModel, and exposes `generate`.
+
+ def __init__(
+ self,
+ config,
+ checkpoint_dir,
+ device_id=0,
+ rank=0,
+ t5_fsdp=False,
+ dit_fsdp=False,
+ use_usp=False,
+ t5_cpu=False,
+ init_on_cpu=True,
+ use_vae_parallel=False,
+ ):
+ r"""
+ Initializes the first-last-frame-to-video generation model components.
+
+ Args:
+ config (EasyDict):
+ Object containing model parameters initialized from config.py
+ checkpoint_dir (`str`):
+ Path to directory containing model checkpoints
+ device_id (`int`, *optional*, defaults to 0):
+ Id of target GPU device
+ rank (`int`, *optional*, defaults to 0):
+ Process rank for distributed training
+ t5_fsdp (`bool`, *optional*, defaults to False):
+ Enable FSDP sharding for T5 model
+ dit_fsdp (`bool`, *optional*, defaults to False):
+ Enable FSDP sharding for DiT model
+ use_usp (`bool`, *optional*, defaults to False):
+ Enable distribution strategy of USP.
+ t5_cpu (`bool`, *optional*, defaults to False):
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
+ init_on_cpu (`bool`, *optional*, defaults to True):
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+ use_vae_parallel (`bool`, *optional*, defaults to False):
+ Patch the VAE encoder and decoder forward passes with `set_vae_patch_parallel`.
+ Queries `dist.get_world_size()`, so the process group must already be initialized.
+ """
+ self.device = torch.device(f"cuda:{device_id}")
+ self.config = config
+ self.rank = rank
+ self.use_usp = use_usp
+ self.t5_cpu = t5_cpu
+
+ self.num_train_timesteps = config.num_train_timesteps
+ self.param_dtype = config.param_dtype
+
+ shard_fn = partial(shard_model, device_id=device_id)
+ self.text_encoder = T5EncoderModel(
+ text_len=config.text_len,
+ dtype=config.t5_dtype,
+ device=self.device,
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
+ shard_fn=shard_fn if t5_fsdp else None,
+ )
+
+ self.vae_stride = config.vae_stride
+ self.patch_size = config.patch_size
+ self.vae = WanVAE(
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+ device=self.device,
+ dtype=self.param_dtype)
+ if use_vae_parallel:
+ # Fewer than 8 ranks: one group holding every rank. Otherwise: consecutive groups of 8 ranks.
+ all_pp_group_ranks = []
+ if dist.get_world_size() < 8 :
+ all_pp_group_ranks.append(list(range(0, dist.get_world_size())))
+ set_vae_patch_parallel(self.vae.model, dist.get_world_size(), 1, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="decoder.forward")
+ set_vae_patch_parallel(self.vae.model, dist.get_world_size(), 1, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="encoder.forward")
+ else:
+ for i in range(0, dist.get_world_size() // 8):
+ all_pp_group_ranks.append(list(range(8 * i, 8 * (i + 1))))
+ set_vae_patch_parallel(self.vae.model, 4, 2, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="decoder.forward")
+ set_vae_patch_parallel(self.vae.model, 4, 2, all_pp_group_ranks= all_pp_group_ranks, decoder_decode="encoder.forward")
+
+
+ self.clip = CLIPModel(
+ dtype=config.clip_dtype,
+ device=self.device,
+ checkpoint_path=os.path.join(checkpoint_dir,
+ config.clip_checkpoint),
+ tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
+
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
+ self.model = WanModel.from_pretrained(checkpoint_dir, torch_dtype=self.param_dtype)
+ self.model.eval().requires_grad_(False)
+
+ if t5_fsdp or dit_fsdp or use_usp:
+ init_on_cpu = False
+
+ if use_usp:
+
+ # USP swaps in the context-parallel forward implementations on each block and on the model.
+ from .distributed.xdit_context_parallel import (
+ usp_attn_forward,
+ usp_dit_forward,
+ )
+ for block in self.model.blocks:
+ block.self_attn.forward = types.MethodType(
+ usp_attn_forward, block.self_attn)
+ self.model.forward = types.MethodType(usp_dit_forward, self.model)
+ self.sp_size = get_sequence_parallel_world_size()
+ else:
+ self.sp_size = 1
+
+ if dist.is_initialized():
+ dist.barrier()
+ if dit_fsdp:
+ self.model = shard_fn(self.model)
+ else:
+ if not init_on_cpu:
+ self.model.to(self.device)
+
+ self.sample_neg_prompt = config.sample_neg_prompt
+
+ def generate(self,
+ input_prompt,
+ first_frame,
+ last_frame,
+ max_area=720 * 1280,
+ frame_num=81,
+ shift=16,
+ sample_solver='unipc',
+ sampling_steps=50,
+ guide_scale=5.5,
+ n_prompt="",
+ seed=-1,
+ offload_model=True):
+ r"""
+ Generates video frames from input first-last frame and text prompt using diffusion process.
+
+ Args:
+ input_prompt (`str`):
+ Text prompt for content generation.
+ first_frame (PIL.Image.Image):
+ First frame as a PIL image; converted internally to a tensor scaled to [-1, 1].
+ last_frame (PIL.Image.Image):
+ Last frame as a PIL image; converted internally to a tensor scaled to [-1, 1].
+ [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized
+ to match first_frame.
+ max_area (`int`, *optional*, defaults to 720*1280):
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
+ frame_num (`int`, *optional*, defaults to 81):
+ How many frames to sample from a video. The number should be 4n+1
+ shift (`float`, *optional*, defaults to 16):
+ Noise schedule shift parameter. Affects temporal dynamics
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
+ Solver used to sample the video. Supported values are 'unipc' and 'dpm++'.
+ sampling_steps (`int`, *optional*, defaults to 50):
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
+ guide_scale (`float`, *optional*, defaults to 5.5):
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
+ n_prompt (`str`, *optional*, defaults to ""):
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+ seed (`int`, *optional*, defaults to -1):
+ Random seed for noise generation. If -1, use random seed
+ offload_model (`bool`, *optional*, defaults to True):
+ If True, offloads models to CPU during generation to save VRAM
+
+ Returns:
+ torch.Tensor:
+ Generated video frames tensor on rank 0, `None` on every other rank. Dimensions: (C, N, H, W) where:
+ - C: Color channels (3 for RGB)
+ - N: Number of frames (81)
+ - H: Frame height (from max_area)
+ - W: Frame width (from max_area)
+
+ Raises:
+ NotImplementedError: If `sample_solver` is neither 'unipc' nor 'dpm++'.
+ """
+ first_frame_size = first_frame.size
+ last_frame_size = last_frame.size
+ first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
+ self.device)
+ last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
+ self.device)
+
+ F = frame_num
+ first_frame_h, first_frame_w = first_frame.shape[1:]
+ aspect_ratio = first_frame_h / first_frame_w
+ # Latent height/width: fit max_area at the first frame's aspect ratio, floored to a multiple of the patch size.
+ lat_h = round(
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
+ self.patch_size[1] * self.patch_size[1])
+ lat_w = round(
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
+ self.patch_size[2] * self.patch_size[2])
+ first_frame_h = lat_h * self.vae_stride[1]
+ first_frame_w = lat_w * self.vae_stride[2]
+ # NOTE(review): despite the "1. resize" step label, no resize call is made here; only
+ # `TF.center_crop` runs, with the scaled last-frame size as its target. PIL `.size` is
+ # (width, height) while `center_crop` expects (height, width) — confirm the intended behavior.
+ if first_frame_size != last_frame_size:
+ # 1. resize
+ last_frame_resize_ratio = max(
+ first_frame_size[0] / last_frame_size[0],
+ first_frame_size[1] / last_frame_size[1])
+ last_frame_size = [
+ round(last_frame_size[0] * last_frame_resize_ratio),
+ round(last_frame_size[1] * last_frame_resize_ratio),
+ ]
+ # 2. center crop
+ last_frame = TF.center_crop(last_frame, last_frame_size)
+
+ # Token count of the patchified latent video, rounded up to a multiple of the sequence-parallel size.
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
+ self.patch_size[1] * self.patch_size[2])
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
+
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+ seed_g = torch.Generator(device=self.device)
+ seed_g.manual_seed(seed)
+
+ latent_frame_num = (F - 1) // self.vae_stride[0] + 1
+ noise = torch.randn(
+ 16,
+ latent_frame_num,
+ lat_h,
+ lat_w,
+ dtype=torch.float32,
+ generator=seed_g,
+ device=self.device)
+
+ # Mask is 1 for the first and last frame and 0 in between; the first frame is repeated 4x so the
+ # frame axis can be folded into groups of 4.
+ msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
+ msk[:, 1:-1] = 0
+ msk = torch.concat([
+ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
+ ],
+ dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+ msk = msk.transpose(1, 2)[0]
+
+ if n_prompt == "":
+ n_prompt = self.sample_neg_prompt
+
+ # preprocess
+ if not self.t5_cpu:
+ self.text_encoder.model.to(self.device)
+ context = self.text_encoder([input_prompt], self.device)
+ context_null = self.text_encoder([n_prompt], self.device)
+ if offload_model:
+ self.text_encoder.model.cpu()
+ else:
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+ context = [t.to(self.device) for t in context]
+ context_null = [t.to(self.device) for t in context_null]
+
+ self.clip.model.to(self.device)
+ clip_context = self.clip.visual(
+ [first_frame[:, None, :, :], last_frame[:, None, :, :]])
+ if offload_model:
+ self.clip.model.cpu()
+
+ # VAE input: resized first frame, F - 2 all-zero frames, resized last frame, concatenated on the frame axis.
+ encode_input = torch.concat([
+ torch.nn.functional.interpolate(
+ first_frame[None].to(self.device),
+ size=(first_frame_h, first_frame_w),
+ mode='bicubic').transpose(0, 1),
+ torch.zeros(3, F - 2, first_frame_h, first_frame_w, device = self.device),
+ torch.nn.functional.interpolate(
+ last_frame[None].to(self.device),
+ size=(first_frame_h, first_frame_w),
+ mode='bicubic').transpose(0, 1),
+ ],
+ dim=1)
+ with VAE_patch_parallel():
+ y = self.vae.encode([
+ encode_input
+ ])[0]
+ y = torch.concat([msk, y])
+
+ @contextmanager
+ def noop_no_sync():
+ yield
+
+ # Use the model's own `no_sync` when it has one, otherwise the no-op context manager above.
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+
+ # evaluation mode
+ with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
+
+ if sample_solver == 'unipc':
+ sample_scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps,
+ shift=1,
+ use_dynamic_shifting=False)
+ sample_scheduler.set_timesteps(
+ sampling_steps, device=self.device, shift=shift)
+ timesteps = sample_scheduler.timesteps
+ elif sample_solver == 'dpm++':
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
+ num_train_timesteps=self.num_train_timesteps,
+ shift=1,
+ use_dynamic_shifting=False)
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+ timesteps, _ = retrieve_timesteps(
+ sample_scheduler,
+ device=self.device,
+ sigmas=sampling_sigmas)
+ else:
+ raise NotImplementedError("Unsupported solver.")
+
+ # sample videos
+ latent = noise
+
+ arg_c = {
+ 'context': [context[0]],
+ 'clip_fea': clip_context,
+ 'seq_len': max_seq_len,
+ 'y': [y],
+ }
+
+ arg_null = {
+ 'context': context_null,
+ 'clip_fea': clip_context,
+ 'seq_len': max_seq_len,
+ 'y': [y],
+ }
+
+ # Used when the CFG group has two ranks: CFG rank 0 gets the prompt context, the other rank the negative one.
+ arg_all = {
+ 'context': [context[0]] if get_classifier_free_guidance_rank()==0 else context_null,
+ 'clip_fea': clip_context,
+ 'seq_len': max_seq_len,
+ 'y': [y],
+ }
+
+ if offload_model:
+ torch.cuda.empty_cache()
+
+ self.model.to(self.device)
+ for _, t in enumerate(tqdm(timesteps)):
+ latent_model_input = [latent.to(self.device)]
+ timestep = [t]
+
+ timestep = torch.stack(timestep).to(self.device)
+
+ if get_classifier_free_guidance_world_size() == 2:
+ # One forward pass per rank; the conditional and unconditional predictions are exchanged by all_gather.
+ noise_pred = self.model(
+ latent_model_input, t=timestep, **arg_all)[0].to(
+ torch.device('cpu') if offload_model else self.device)
+ noise_pred_cond, noise_pred_uncond = get_cfg_group().all_gather(
+ noise_pred, separate_tensors=True
+ )
+ if offload_model:
+ torch.cuda.empty_cache()
+ else:
+ noise_pred_cond = self.model(
+ latent_model_input, t=timestep, **arg_c)[0].to(
+ torch.device('cpu') if offload_model else self.device)
+ if offload_model:
+ torch.cuda.empty_cache()
+ noise_pred_uncond = self.model(
+ latent_model_input, t=timestep, **arg_null)[0].to(
+ torch.device('cpu') if offload_model else self.device)
+ if offload_model:
+ torch.cuda.empty_cache()
+ # Classifier-free guidance combination.
+ noise_pred = noise_pred_uncond + guide_scale * (
+ noise_pred_cond - noise_pred_uncond)
+
+ latent = latent.to(
+ torch.device('cpu') if offload_model else self.device)
+
+ temp_x0 = sample_scheduler.step(
+ noise_pred.unsqueeze(0),
+ t,
+ latent.unsqueeze(0),
+ return_dict=False,
+ generator=seed_g)[0]
+ latent = temp_x0.squeeze(0)
+
+ x0 = [latent.to(self.device)]
+ del latent_model_input, timestep
+
+ if offload_model:
+ self.model.cpu()
+ torch.cuda.empty_cache()
+
+ with VAE_patch_parallel():
+ videos = self.vae.decode(x0)
+
+ del noise, latent
+ del sample_scheduler
+ # NOTE(review): `unwrap_fsdp` inspects `self.model` rather than its `model` argument; the result is
+ # the same for the single call below, which passes `self.model`.
+ def unwrap_fsdp(model):
+ if hasattr(self.model, '_fsdp_wrapped_module'):
+ return self.model._fsdp_wrapped_module
+ return model
+ unwrap_fsdp(self.model).freqs_list = None
+
+ if offload_model:
+ gc.collect()
+ torch.cuda.synchronize()
+ if dist.is_initialized():
+ dist.barrier()
+
+ return videos[0] if self.rank == 0 else None
diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 9c7cbfd..1e1a659 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -161,7 +161,7 @@ def attention(
opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD")
else:
out = attention_forward(q, k, v,
- opt_mode="manual", op_type="fused_attn_score", layout="BNSD")
+ opt_mode="manual", op_type="fused_attn_score", layout="BSND")
return out.to(qtype)
elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
diff --git a/wan/modules/attn_layer.py b/wan/modules/attn_layer.py
index ec75b7b..6901cd4 100644
--- a/wan/modules/attn_layer.py
+++ b/wan/modules/attn_layer.py
@@ -58,7 +58,7 @@ def __init__(
)
self.world_size = dist.get_world_size()
self.args = args
- self.video_size = ['480*832', '832*480', '480*720', '720*480']
+ self.video_size = ['480*832', '832*480', '480*720', '720*480', '1024*1024']
self.algo = int(os.getenv('ALGO', 0))
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 93f6607..73c88da 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -14,6 +14,8 @@
__all__ = ['WanModel']
+T5_CONTEXT_TOKEN_NUMBER = 512
+FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2
def sinusoidal_embedding_1d(dim, position):
# preprocess
@@ -39,21 +41,14 @@ def rope_params(max_seq_len, dim, theta=10000):
return freqs
-@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs_list):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
- s, n, c = x.size(1), x.size(2), x.size(3)
- output = []
- for i, (f, h, w) in enumerate(grid_sizes.tolist()):
- x_i = x[i, :s].reshape(1, s, n, c)
- cos, sin = freqs_list[i]
- x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True)
- output.append(x_i)
- return torch.cat(output).float()
+ cos, sin = freqs_list[0]
+ return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True)
class WanRMSNorm(nn.Module):
@@ -214,8 +209,9 @@ def forward(self, x, context, context_lens):
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
- context_img = context[:, :257]
- context = context[:, 257:]
+ image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
+ context_img = context[:, :image_context_length]
+ context = context[:, image_context_length:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
@@ -283,7 +279,7 @@ def __init__(self,
nn.Linear(ffn_dim, dim))
# modulation
- self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
# Attention_cache
self.cache = None
@@ -353,7 +349,7 @@ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
self.head = nn.Linear(dim, out_dim)
# modulation
- self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)
def forward(self, x, e):
r"""
@@ -370,15 +366,21 @@ def forward(self, x, e):
class MLPProj(torch.nn.Module):
- def __init__(self, in_dim, out_dim):
+ def __init__(self, in_dim, out_dim, flf_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim))
+ if flf_pos_emb: # NOTE: we only use this for `flf2v`
+ self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
def forward(self, image_embeds):
+ if hasattr(self, 'emb_pos'):
+ bs, n, d = image_embeds.shape
+ image_embeds = image_embeds.view(-1, 2 * n, d)
+ image_embeds = image_embeds + self.emb_pos
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@@ -415,7 +417,7 @@ def __init__(self,
Args:
model_type (`str`, *optional*, defaults to 't2v'):
- Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
@@ -448,7 +450,7 @@ def __init__(self,
super().__init__()
- assert model_type in ['t2v', 'i2v']
+ assert model_type in ['t2v', 'i2v', 'flf2v']
self.model_type = model_type
self.patch_size = patch_size
@@ -498,8 +500,8 @@ def __init__(self,
],
dim=1)
- if model_type == 'i2v':
- self.img_emb = MLPProj(1280, dim)
+ if model_type == 'i2v' or model_type == 'flf2v':
+ self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
# initialize weights
self.init_weights()
@@ -528,7 +530,7 @@ def forward(
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
- CLIP image features for image-to-video mode
+ CLIP image features for image-to-video mode or first-last-frame-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
@@ -536,7 +538,7 @@ def forward(
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
- if self.model_type == 'i2v':
+ if self.model_type == 'i2v' or self.model_type == 'flf2v':
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
@@ -575,7 +577,7 @@ def forward(
]))
if clip_fea is not None:
- context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim
context = torch.concat([context_clip, context], dim=1)
if self.freqs_list is None:
diff --git a/wan/utils/prompt_extend.py b/wan/utils/prompt_extend.py
index f0ef38a..f3ecea7 100644
--- a/wan/utils/prompt_extend.py
+++ b/wan/utils/prompt_extend.py
@@ -7,7 +7,7 @@
import tempfile
from dataclasses import dataclass
from http import HTTPStatus
-from typing import Optional, Union
+from typing import Optional, Union, List
import dashscope
import torch
@@ -47,7 +47,7 @@
'''4. Prompts should match the user’s intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
'''5. Emphasize motion information and different camera movements present in the input description;\n''' \
'''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
- '''7. The revised prompt should be around 80-100 characters long.\n''' \
+ '''7. The revised prompt should be around 80-100 words long.\n''' \
'''Revised prompt examples:\n''' \
'''1. Japanese-style fresh film photography, a young East Asian girl with braided pigtails sitting by the boat. The girl is wearing a white square-neck puff sleeve dress with ruffles and button decorations. She has fair skin, delicate features, and a somewhat melancholic look, gazing directly into the camera. Her hair falls naturally, with bangs covering part of her forehead. She is holding onto the boat with both hands, in a relaxed posture. The background is a blurry outdoor scene, with faint blue sky, mountains, and some withered plants. Vintage film texture photo. Medium shot half-body portrait in a seated position.\n''' \
'''2. Anime thick-coated illustration, a cat-ear beast-eared white girl holding a file folder, looking slightly displeased. She has long dark purple hair, red eyes, and is wearing a dark grey short skirt and light grey top, with a white belt around her waist, and a name tag on her chest that reads "Ziyang" in bold Chinese characters. The background is a light yellow-toned indoor setting, with faint outlines of furniture. There is a pink halo above the girl's head. Smooth line Japanese cel-shaded style. Close-up half-body slightly overhead view.\n''' \
@@ -97,6 +97,59 @@
'''Directly output the rewritten English text.'''
+VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师,旨在参考用户输入的图像的细节内容,把用户输入的Prompt改写为优质Prompt,使其更完整、更具表现力,同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写,严格参考示例的格式进行改写
+任务要求:
+1. 用户会输入两张图片,第一张是视频的第一帧,第二张是视频的最后一帧,你需要综合两个照片的内容进行优化改写
+2. 对于过于简短的用户输入,在不改变原意前提下,合理推断并补充细节,使得画面更加完整好看;
+3. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)、画面风格、空间关系、镜头景别;
+4. 整体中文输出,保留引号、书名号中原文以及重要的输入信息,不要改写;
+5. Prompt应匹配符合用户意图且精准细分的风格描述。如果用户未指定,则根据用户提供的照片的风格,你需要仔细分析照片的风格,并参考风格进行改写。
+6. 如果Prompt是古诗词,应该在生成的Prompt中强调中国古典元素,避免出现西方、现代、外国场景;
+7. 你需要强调输入中的运动信息和不同的镜头运镜;
+8. 你的输出应当带有自然运动属性,需要根据描述主体目标类别增加这个目标的自然动作,描述尽可能用简单直接的动词;
+9. 你需要尽可能的参考图片的细节信息,如人物动作、服装、背景等,强调照片的细节元素;
+10. 你需要强调两画面可能出现的潜在变化,如“走进”,“出现”,“变身成”,“镜头左移”,“镜头右移动”,“镜头上移动”, “镜头下移”等等;
+11. 无论用户输入哪种语言,你都需要输出中文;
+12. 改写后的prompt字数控制在80-100字左右;
+改写后 prompt 示例:
+1. 日系小清新胶片写真,扎着双麻花辫的年轻东亚女孩坐在船边。女孩穿着白色方领泡泡袖连衣裙,裙子上有褶皱和纽扣装饰。她皮肤白皙,五官清秀,眼神略带忧郁,直视镜头。女孩的头发自然垂落,刘海遮住部分额头。她双手扶船,姿态自然放松。背景是模糊的户外场景,隐约可见蓝天、山峦和一些干枯植物。复古胶片质感照片。中景半身坐姿人像。
+2. 二次元厚涂动漫插画,一个猫耳兽耳白人少女手持文件夹,神情略带不满。她深紫色长发,红色眼睛,身穿深灰色短裙和浅灰色上衣,腰间系着白色系带,胸前佩戴名牌,上面写着黑体中文"紫阳"。淡黄色调室内背景,隐约可见一些家具轮廓。少女头顶有一个粉色光圈。线条流畅的日系赛璐璐风格。近景半身略俯视视角。
+3. CG游戏概念数字艺术,一只巨大的鳄鱼张开大嘴,背上长着树木和荆棘。鳄鱼皮肤粗糙,呈灰白色,像是石头或木头的质感。它背上生长着茂盛的树木、灌木和一些荆棘状的突起。鳄鱼嘴巴大张,露出粉红色的舌头和锋利的牙齿。画面背景是黄昏的天空,远处有一些树木。场景整体暗黑阴冷。近景,仰视视角。
+4. 美剧宣传海报风格,身穿黄色防护服的Walter White坐在金属折叠椅上,上方无衬线英文写着"Breaking Bad",周围是成堆的美元和蓝色塑料储物箱。他戴着眼镜目光直视前方,身穿黄色连体防护服,双手放在膝盖上,神态稳重自信。背景是一个废弃的阴暗厂房,窗户透着光线。带有明显颗粒质感纹理。中景,镜头下移。
+请直接输出改写后的文本,不要进行多余的回复。"""
+
+VL_EN_SYS_PROMPT_FOR_MULTI_IMAGES = \
+ '''You are a prompt optimization specialist whose goal is to rewrite the user's input prompts into high-quality English prompts by referring to the details of the user's input images, making them more complete and expressive while maintaining the original meaning. You need to integrate the content of the user's photo with the input prompt for the rewrite, strictly adhering to the formatting of the examples provided.\n''' \
+ '''Task Requirements:\n''' \
+ '''1. The user will input two images, the first is the first frame of the video, and the second is the last frame of the video. You need to integrate the content of the two photos with the input prompt for the rewrite.\n''' \
+ '''2. For overly brief user inputs, reasonably infer and supplement details without changing the original meaning, making the image more complete and visually appealing;\n''' \
+ '''3. Improve the characteristics of the main subject in the user's description (such as appearance, expression, quantity, ethnicity, posture, etc.), rendering style, spatial relationships, and camera angles;\n''' \
+ '''4. The overall output should be in English, retaining original text in quotes and book titles as well as important input information without rewriting them;\n''' \
+ '''5. The prompt should match the user’s intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \
+ '''6. If the prompt is an ancient poem, classical Chinese elements should be emphasized in the generated prompt, avoiding references to Western, modern, or foreign scenes;\n''' \
+ '''7. You need to emphasize movement information in the input and different camera movements;\n''' \
+ '''8. Your output should convey natural movement attributes, incorporating natural actions related to the described subject category, using simple and direct verbs as much as possible;\n''' \
+ '''9. You should reference the detailed information in the image, such as character actions, clothing, backgrounds, and emphasize the details in the photo;\n''' \
+ '''10. You need to emphasize potential changes that may occur between the two frames, such as "walking into", "appearing", "turning into", "camera left", "camera right", "camera up", "camera down", etc.;\n''' \
+ '''11. Control the rewritten prompt to around 80-100 words.\n''' \
+ '''12. No matter what language the user inputs, you must always output in English.\n''' \
+ '''Example of the rewritten English prompt:\n''' \
+ '''1. A Japanese fresh film-style photo of a young East Asian girl with double braids sitting by the boat. The girl wears a white square collar puff sleeve dress, decorated with pleats and buttons. She has fair skin, delicate features, and slightly melancholic eyes, staring directly at the camera. Her hair falls naturally, with bangs covering part of her forehead. She rests her hands on the boat, appearing natural and relaxed. The background features a blurred outdoor scene, with hints of blue sky, mountains, and some dry plants. The photo has a vintage film texture. A medium shot of a seated portrait.\n''' \
+ '''2. An anime illustration in vibrant thick painting style of a white girl with cat ears holding a folder, showing a slightly dissatisfied expression. She has long dark purple hair and red eyes, wearing a dark gray skirt and a light gray top with a white waist tie and a name tag in bold Chinese characters that says "紫阳" (Ziyang). The background has a light yellow indoor tone, with faint outlines of some furniture visible. A pink halo hovers above her head, in a smooth Japanese cel-shading style. A close-up shot from a slightly elevated perspective.\n''' \
+ '''3. CG game concept digital art featuring a huge crocodile with its mouth wide open, with trees and thorns growing on its back. The crocodile's skin is rough and grayish-white, resembling stone or wood texture. Its back is lush with trees, shrubs, and thorny protrusions. With its mouth agape, the crocodile reveals a pink tongue and sharp teeth. The background features a dusk sky with some distant trees, giving the overall scene a dark and cold atmosphere. A close-up from a low angle.\n''' \
+ '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
+ '''Directly output the rewritten English text.'''
+
+SYSTEM_PROMPT_TYPES = {
+ int(b'000', 2): LM_EN_SYS_PROMPT,
+ int(b'001', 2): LM_ZH_SYS_PROMPT,
+ int(b'010', 2): VL_EN_SYS_PROMPT,
+ int(b'011', 2): VL_ZH_SYS_PROMPT,
+ int(b'110', 2): VL_EN_SYS_PROMPT_FOR_MULTI_IMAGES,
+ int(b'111', 2): VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES
+}
+
+
@dataclass
class PromptOutput(object):
status: bool
@@ -128,21 +181,25 @@ def extend_with_img(self,
def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
pass
- def decide_system_prompt(self, tar_lang="zh"):
+ def decide_system_prompt(self, tar_lang="zh", multi_images_input=False):
zh = tar_lang == "zh"
- if zh:
- return LM_ZH_SYS_PROMPT if not self.is_vl else VL_ZH_SYS_PROMPT
- else:
- return LM_EN_SYS_PROMPT if not self.is_vl else VL_EN_SYS_PROMPT
+ self.is_vl |= multi_images_input
+ task_type = zh + (self.is_vl << 1) + (multi_images_input << 2)
+ return SYSTEM_PROMPT_TYPES[task_type]
def __call__(self,
prompt,
+ system_prompt=None,
tar_lang="zh",
image=None,
seed=-1,
*args,
**kwargs):
- system_prompt = self.decide_system_prompt(tar_lang=tar_lang)
+ if system_prompt is None:
+ system_prompt = self.decide_system_prompt(
+ tar_lang=tar_lang,
+ multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1
+ )
if seed < 0:
seed = random.randint(0, sys.maxsize)
if image is not None and self.is_vl:
@@ -232,38 +289,42 @@ def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
def extend_with_img(self,
prompt,
system_prompt,
- image: Union[Image.Image, str] = None,
+ image: Union[List[Image.Image], List[str], Image.Image, str] = None,
seed=-1,
*args,
**kwargs):
- if isinstance(image, str):
- image = Image.open(image).convert('RGB')
- w = image.width
- h = image.height
- area = min(w * h, self.max_image_size)
- aspect_ratio = h / w
- resized_h = round(math.sqrt(area * aspect_ratio))
- resized_w = round(math.sqrt(area / aspect_ratio))
- image = image.resize((resized_w, resized_h))
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
- image.save(f.name)
- fname = f.name
- image_path = f"file://{f.name}"
+
+ def ensure_image(_image):
+ if isinstance(_image, str):
+ _image = Image.open(_image).convert('RGB')
+ w = _image.width
+ h = _image.height
+ area = min(w * h, self.max_image_size)
+ aspect_ratio = h / w
+ resized_h = round(math.sqrt(area * aspect_ratio))
+ resized_w = round(math.sqrt(area / aspect_ratio))
+ _image = _image.resize((resized_w, resized_h))
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
+ _image.save(f.name)
+ image_path = f"file://{f.name}"
+ return image_path
+ if not isinstance(image, (list, tuple)):
+ image = [image]
+ image_path_list = [ensure_image(_image) for _image in image]
+ role_content = [
+ {"text": prompt},
+ *[{"image": image_path} for image_path in image_path_list]
+ ]
+ system_content = [{"text": system_prompt}]
prompt = f"{prompt}"
messages = [
{
'role': 'system',
- 'content': [{
- "text": system_prompt
- }]
+ 'content': system_content
},
{
'role': 'user',
- 'content': [{
- "text": prompt
- }, {
- "image": image_path
- }]
+ 'content': role_content
},
]
response = None
@@ -286,7 +347,8 @@ def extend_with_img(self,
except Exception as e:
exception = e
result_prompt = result_prompt.replace('\n', '\\n')
- os.remove(fname)
+ for image_path in image_path_list:
+ os.remove(image_path.removeprefix('file://'))
return PromptOutput(
status=status,
@@ -397,30 +459,36 @@ def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
def extend_with_img(self,
prompt,
system_prompt,
- image: Union[Image.Image, str] = None,
+ image: Union[List[Image.Image], List[str], Image.Image, str] = None,
seed=-1,
*args,
**kwargs):
self.model = self.model.to(self.device)
+
+ if not isinstance(image, (list, tuple)):
+ image = [image]
+
+ system_content = [{
+ "type": "text",
+ "text": system_prompt
+ }]
+ role_content = [
+ {
+ "type": "text",
+ "text": prompt
+ },
+ *[
+ {"image": image_path} for image_path in image
+ ]
+ ]
+
messages = [{
'role': 'system',
- 'content': [{
- "type": "text",
- "text": system_prompt
- }]
+ 'content': system_content,
}, {
"role":
"user",
- "content": [
- {
- "type": "image",
- "image": image,
- },
- {
- "type": "text",
- "text": prompt
- },
- ],
+ "content": role_content,
}]
# Preparation for inference
@@ -500,7 +568,8 @@ def extend_with_img(self,
# test case for prompt-image extend
ds_model_name = "qwen-vl-max"
#qwen_model_name = "./models/Qwen2.5-VL-3B-Instruct/" #VRAM: 9686MiB
- qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct-AWQ/" # VRAM: 8492
+ # qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct-AWQ/" # VRAM: 8492
+ qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct/"
image = "./examples/i2v_input.JPG"
# test dashscope api why image_path is local directory; skip
@@ -541,3 +610,26 @@ def extend_with_img(self,
en_prompt, tar_lang="en", image=image, seed=seed)
print("VL qwen vl en result -> en",
qwen_result.prompt) # , qwen_result.system_prompt)
+ # test multi images
+ image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"]
+ prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。"
+ en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
+ "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
+ "resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
+ "architectural structures, combining to create a tranquil and breathtaking coastal landscape.")
+
+ dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
+ dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+ print("VL dashscope result -> zh", dashscope_result.prompt)
+
+ dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
+ dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)
+ print("VL dashscope en result -> zh", dashscope_result.prompt)
+
+ qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
+ qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
+ print("VL qwen result -> zh", qwen_result.prompt)
+
+ qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
+ qwen_result = qwen_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)  # English prompt in, Chinese out, as the label below states
+ print("VL qwen en result -> zh", qwen_result.prompt)
\ No newline at end of file