From 4c4203160f8526a403e3c0381b9fa13f6c208347 Mon Sep 17 00:00:00 2001
From: SHYuanBest
Date: Mon, 23 Dec 2024 20:46:57 +0800
Subject: [PATCH] add USP example for ConsisID

---
 examples/consisid_example.py     |   6 +-
 examples/consisid_usp_example.py | 256 +++++++++++++++++++++++++++++++
 examples/run_consisid_usp.sh     |  41 +++++
 3 files changed, 301 insertions(+), 2 deletions(-)
 create mode 100644 examples/consisid_usp_example.py
 create mode 100644 examples/run_consisid_usp.sh

diff --git a/examples/consisid_example.py b/examples/consisid_example.py
index 17c4ffe5..067435e1 100644
--- a/examples/consisid_example.py
+++ b/examples/consisid_example.py
@@ -3,9 +3,11 @@
 import time
 import torch
 import torch.distributed
+
 from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
 from diffusers.utils import export_to_video
 from huggingface_hub import snapshot_download
+
 from xfuser import xFuserConsisIDPipeline, xFuserArgs
 from xfuser.config import FlexibleArgumentParser
 from xfuser.core.distributed import (
@@ -32,7 +34,7 @@ def main():
     else:
         print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.")
 
-    # 2. Load Pipeline.
+    # 2. Load Pipeline
     device = torch.device(f"cuda:{local_rank}")
     pipe = xFuserConsisIDPipeline.from_pretrained(
         pretrained_model_name_or_path=engine_config.model_config.model,
@@ -113,4 +115,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/examples/consisid_usp_example.py b/examples/consisid_usp_example.py
new file mode 100644
index 00000000..b02a26b5
--- /dev/null
+++ b/examples/consisid_usp_example.py
@@ -0,0 +1,256 @@
+import functools
+import logging
+import os
+import time
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+
+from diffusers import ConsisIDPipeline, DiffusionPipeline
+from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
+from diffusers.utils import export_to_video
+from huggingface_hub import snapshot_download
+
+from xfuser import xFuserArgs
+from xfuser.config import FlexibleArgumentParser
+from xfuser.core.distributed import (
+    get_world_group,
+    get_runtime_state,
+    get_classifier_free_guidance_world_size,
+    get_classifier_free_guidance_rank,
+    get_cfg_group,
+    get_sequence_parallel_world_size,
+    get_sequence_parallel_rank,
+    get_sp_group,
+    is_dp_last_group,
+    initialize_runtime_state,
+    get_pipeline_parallel_world_size,
+)
+from xfuser.model_executor.layers.attention_processor import xFuserConsisIDAttnProcessor2_0
+
+
+def parallelize_transformer(pipe: DiffusionPipeline):
+    transformer = pipe.transformer
+    original_forward = transformer.forward
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        id_cond: Optional[torch.Tensor] = None,
+        id_vit_hidden: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() != 0:
+            get_runtime_state().split_text_embed_in_sp = False
+        else:
+            get_runtime_state().split_text_embed_in_sp = True
+
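+        # Shard the inputs: the batch dim is split across the classifier-free-
+        # guidance group and the sequence dim across the sequence-parallel
+        # group, so each rank runs the wrapped forward on its own slice only.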
+        temporal_size = hidden_states.shape[1]
+        if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
+            timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
+        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
+        hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
+        if get_runtime_state().split_text_embed_in_sp:
+            encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+
+        if image_rotary_emb is not None:
+            freqs_cos, freqs_sin = image_rotary_emb
+
+            def get_rotary_emb_chunk(freqs):
+                dim_thw = freqs.shape[-1]
+                freqs = freqs.reshape(temporal_size, -1, dim_thw)
+                freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+                freqs = freqs.reshape(-1, dim_thw)
+                return freqs
+
+            freqs_cos = get_rotary_emb_chunk(freqs_cos)
+            freqs_sin = get_rotary_emb_chunk(freqs_sin)
+            image_rotary_emb = (freqs_cos, freqs_sin)
+
+        for block in transformer.transformer_blocks:
+            block.attn1.processor = xFuserConsisIDAttnProcessor2_0()
+
+        output = original_forward(
+            hidden_states,
+            encoder_hidden_states,
+            timestep=timestep,
+            timestep_cond=timestep_cond,
+            image_rotary_emb=image_rotary_emb,
+            attention_kwargs=attention_kwargs,
+            id_cond=id_cond,
+            id_vit_hidden=id_vit_hidden,
+            **kwargs,
+        )
+
+        return_dict = not isinstance(output, tuple)
+        sample = output[0]
+        sample = get_sp_group().all_gather(sample, dim=-2)
+        sample = get_cfg_group().all_gather(sample, dim=0)
+        if return_dict:
+            return output.__class__(sample, *output[1:])
+        return (sample, *output[1:])
+
+    new_forward = new_forward.__get__(transformer)
+    transformer.forward = new_forward
+
+    original_patch_embed_forward = transformer.patch_embed.forward
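+
+    # The patch embedder mixes text and image tokens, so every sequence-parallel
+    # rank first all-gathers the full sequence, applies the original embedding,
+    # and then keeps only its own chunk of the text and image tokens.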
+    @functools.wraps(transformer.patch_embed.__class__.forward)
+    def new_patch_embed(
+        self, text_embeds: torch.Tensor, image_embeds: torch.Tensor
+    ):
+        text_embeds = get_sp_group().all_gather(text_embeds.contiguous(), dim=-2)
+        image_embeds = get_sp_group().all_gather(image_embeds.contiguous(), dim=-2)
+        batch, num_frames, channels, height, width = image_embeds.shape
+        text_len = text_embeds.shape[-2]
+
+        output = original_patch_embed_forward(text_embeds, image_embeds)
+
+        text_embeds = output[:, :text_len, :]
+        image_embeds = output[:, text_len:, :].reshape(batch, num_frames, -1, output.shape[-1])
+
+        text_embeds = torch.chunk(text_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+        image_embeds = torch.chunk(image_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+        image_embeds = image_embeds.reshape(batch, -1, image_embeds.shape[-1])
+        return torch.cat([text_embeds, image_embeds], dim=1)
+
+    new_patch_embed = new_patch_embed.__get__(transformer.patch_embed)
+    transformer.patch_embed.forward = new_patch_embed
+
+
+def main():
+    parser = FlexibleArgumentParser(description="xFuser Arguments")
+    args = xFuserArgs.add_cli_args(parser).parse_args()
+    engine_args = xFuserArgs.from_cli_args(args)
+
+    engine_config, input_config = engine_args.create_config()
+    local_rank = get_world_group().local_rank
+
+    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
+    assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for ConsisID"
+
+    # 1. Prepare all the Checkpoints
+    if not os.path.exists(engine_config.model_config.model):
+        print("Base Model not found, downloading from Hugging Face...")
+        snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=engine_config.model_config.model)
+    else:
+        print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.")
+
+    # 2. Load Pipeline
+    device = torch.device(f"cuda:{local_rank}")
+    pipe = ConsisIDPipeline.from_pretrained(
+        pretrained_model_name_or_path=engine_config.model_config.model,
+        torch_dtype=torch.bfloat16,
+    )
+    if args.enable_sequential_cpu_offload:
+        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
+        logging.info(f"rank {local_rank} sequential CPU offload enabled")
+    elif args.enable_model_cpu_offload:
+        pipe.enable_model_cpu_offload(gpu_id=local_rank)
+        logging.info(f"rank {local_rank} model CPU offload enabled")
+    else:
+        pipe = pipe.to(device)
+
+    face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
+        prepare_face_models(engine_config.model_config.model, device=device, dtype=torch.bfloat16)
+    )
+
+    if args.enable_tiling:
+        pipe.vae.enable_tiling()
+
+    if args.enable_slicing:
+        pipe.vae.enable_slicing()
+
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
+    initialize_runtime_state(pipe, engine_config)
+    get_runtime_state().set_video_input_parameters(
+        height=input_config.height,
+        width=input_config.width,
+        num_frames=input_config.num_frames,
+        batch_size=1,
+        num_inference_steps=input_config.num_inference_steps,
+        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
+    )
+    parallelize_transformer(pipe)
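+    # pipe.transformer now runs under USP: each rank works on its own
+    # sequence shard, and the denoised sample is all-gathered across the
+    # sequence-parallel and CFG groups at the end of every forward pass.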
+
+    # 3. Prepare Model Input
+    id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
+        face_helper_1,
+        face_clip_model,
+        face_helper_2,
+        eva_transform_mean,
+        eva_transform_std,
+        face_main_model,
+        device,
+        torch.bfloat16,
+        input_config.img_file_path,
+        is_align_face=True,
+    )
+
+    # 4. Generate Identity-Preserving Video
+    if engine_config.runtime_config.use_torch_compile:
+        torch._inductor.config.reorder_for_compute_comm_overlap = True
+        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
+
+    # one warmup step (also triggers compilation when torch.compile is enabled)
+    output = pipe(
+        image=image,
+        prompt=input_config.prompt[0],
+        id_vit_hidden=id_vit_hidden,
+        id_cond=id_cond,
+        kps_cond=face_kps,
+        height=input_config.height,
+        width=input_config.width,
+        num_frames=input_config.num_frames,
+        num_inference_steps=1,
+        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+        guidance_scale=6.0,
+        use_dynamic_cfg=False,
+    ).frames[0]
+
+    torch.cuda.reset_peak_memory_stats()
+    start_time = time.time()
+
+    output = pipe(
+        image=image,
+        prompt=input_config.prompt[0],
+        id_vit_hidden=id_vit_hidden,
+        id_cond=id_cond,
+        kps_cond=face_kps,
+        height=input_config.height,
+        width=input_config.width,
+        num_frames=input_config.num_frames,
+        num_inference_steps=input_config.num_inference_steps,
+        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+        guidance_scale=6.0,
+        use_dynamic_cfg=False,
+    ).frames[0]
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
+    parallel_info = (
+        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
+        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
+        f"tp{engine_args.tensor_parallel_degree}_"
+        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
+    )
+    if is_dp_last_group():
+        resolution = f"{input_config.width}x{input_config.height}"
+        output_filename = f"results/consisid_{parallel_info}_{resolution}.mp4"
+        export_to_video(output, output_filename, fps=8)
+        print(f"output saved to {output_filename}")
+
+    if get_world_group().rank == get_world_group().world_size - 1:
+        print(f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB")
+    get_runtime_state().destory_distributed_env()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/examples/run_consisid_usp.sh b/examples/run_consisid_usp.sh
new file mode 100644
index 00000000..24ccd84b
--- /dev/null
+++ b/examples/run_consisid_usp.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+set -x
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+
+# ConsisID configuration
+SCRIPT="consisid_usp_example.py"
+MODEL_ID="/cfs/dit/ConsisID-preview"
+INFERENCE_STEP=50
+
+mkdir -p ./results
+
+# ConsisID specific task args
+TASK_ARGS="--height 480 --width 720 --num_frames 49"
+
+# ConsisID parallel configuration
+N_GPUS=4
+PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 2"
+# CFG_ARGS="--use_cfg_parallel"
+
+# Uncomment and modify these as needed
+# PIPEFUSION_ARGS="--num_pipeline_patch 8"
+# OUTPUT_ARGS="--output_type latent"
+# PARALLEL_VAE="--use_parallel_vae"
+# ENABLE_TILING="--enable_tiling"
+# COMPILE_FLAG="--use_torch_compile"
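+
+# Note: ulysses_degree * ring_degree (times the CFG degree when
+# --use_cfg_parallel is set) should multiply out to N_GPUS; xFuser
+# validates the parallel configuration against the world size at startup.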
+
+torchrun --master_port=1234 --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
+--model $MODEL_ID \
+$PARALLEL_ARGS \
+$TASK_ARGS \
+$PIPEFUSION_ARGS \
+$OUTPUT_ARGS \
+--num_inference_steps $INFERENCE_STEP \
+--warmup_steps 0 \
+--prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \
+--img_file_path "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \
+$CFG_ARGS \
+$PARALLEL_VAE \
+$ENABLE_TILING \
+$COMPILE_FLAG
\ No newline at end of file