Skip to content
Binary file added examples/flf2v_input_first_frame.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/flf2v_input_last_frame.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
131 changes: 128 additions & 3 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
"image":
"examples/i2v_input.JPG",
},
"flf2v-14B": {
"prompt":
"CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
"first_frame":
"examples/flf2v_input_first_frame.png",
"last_frame":
"examples/flf2v_input_last_frame.png",
},
}


Expand All @@ -60,6 +68,8 @@ def _validate_args(args):
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
if "flf2v" in args.task:
args.sample_shift = 16

# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
Expand Down Expand Up @@ -187,7 +197,17 @@ def _parse_args():
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the image or video.")
help="[image to video] The image to generate the video from.")
parser.add_argument(
"--first_frame",
type=str,
default=None,
help="[first-last frame to video] The image (first frame) to generate the video from.")
parser.add_argument(
"--last_frame",
type=str,
default=None,
help="[first-last frame to video] The image (last frame) to generate the video from.")
parser.add_argument(
"--image",
type=str,
Expand Down Expand Up @@ -305,7 +325,7 @@ def generate(args):
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
Expand Down Expand Up @@ -440,7 +460,7 @@ def generate(args):
end = time.time()
logging.info(f"Generating video used time {end - begin: .4f}s")

else:
elif "i2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
Expand Down Expand Up @@ -561,7 +581,112 @@ def generate(args):
stream.synchronize()
end = time.time()
logging.info(f"Generating video used time {end - begin: .4f}s")
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.first_frame is None or args.last_frame is None:
args.first_frame = EXAMPLE_PROMPT[args.task]["first_frame"]
args.last_frame = EXAMPLE_PROMPT[args.task]["last_frame"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input first frame: {args.first_frame}")
logging.info(f"Input last frame: {args.last_frame}")
first_frame = Image.open(args.first_frame).convert("RGB")
last_frame = Image.open(args.last_frame).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=[first_frame, last_frame],
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")

logging.info("Creating WanFLF2V pipeline.")
wan_flf2v = wan.WanFLF2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
use_vae_parallel=args.vae_parallel,
)

transformer = wan_flf2v.model
if args.use_attentioncache:
config = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.blocks),
steps_count=args.sample_steps,
step_start=args.start_step,
step_interval=args.attentioncache_interval,
step_end=args.end_step
)
else:
config = CacheConfig(
method="attention_cache",
blocks_count=len(transformer.blocks),
steps_count=args.sample_steps
)
cache = CacheAgent(config)
if args.dit_fsdp:
for block in transformer._fsdp_wrapped_module.blocks:
block._fsdp_wrapped_module.cache = cache
block._fsdp_wrapped_module.args = args
else:
for block in transformer.blocks:
block.cache = cache
block.args = args

logging.info(f"Warm up 2 steps...")
video = wan_flf2v.generate(
args.prompt,
first_frame,
last_frame,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=2,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)

logging.info("Generating video ...")
stream.synchronize()
begin = time.time()

video = wan_flf2v.generate(
args.prompt,
first_frame,
last_frame,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
stream.synchronize()
end = time.time()
logging.info(f"Generating video used time {end - begin: .4f}s")

if rank == 0:
if args.save_file is None:
Expand Down
Loading