Commit 4c42031: update usp

SHYuanBest committed Dec 23, 2024
1 parent ae678fb commit 4c42031

Showing 3 changed files with 301 additions and 2 deletions.
6 changes: 4 additions & 2 deletions examples/consisid_example.py
@@ -3,9 +3,11 @@
 import time
 import torch
 import torch.distributed
+
 from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
 from diffusers.utils import export_to_video
 from huggingface_hub import snapshot_download
+
 from xfuser import xFuserConsisIDPipeline, xFuserArgs
 from xfuser.config import FlexibleArgumentParser
 from xfuser.core.distributed import (
@@ -32,7 +34,7 @@ def main():
     else:
         print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.")
 
-    # 2. Load Pipeline.
+    # 2. Load Pipeline
     device = torch.device(f"cuda:{local_rank}")
     pipe = xFuserConsisIDPipeline.from_pretrained(
         pretrained_model_name_or_path=engine_config.model_config.model,
@@ -113,4 +115,4 @@ def main():


 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
256 changes: 256 additions & 0 deletions examples/consisid_usp_example.py
@@ -0,0 +1,256 @@
import functools
import logging
import os
import time
from typing import Any, Dict, Optional, Tuple

import torch

from diffusers import DiffusionPipeline, ConsisIDPipeline
from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_runtime_state,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sp_group,
    is_dp_last_group,
    initialize_runtime_state,
    get_pipeline_parallel_world_size,
)
from xfuser.model_executor.layers.attention_processor import xFuserConsisIDAttnProcessor2_0


def parallelize_transformer(pipe: DiffusionPipeline):
    transformer = pipe.transformer
    original_forward = transformer.forward

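    # Wrap the transformer's forward so each rank works on its own shard: the
    # batch is split across classifier-free-guidance ranks and the token
    # sequence across sequence-parallel (Ulysses/ring) ranks, with results
    # gathered back at the end.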
    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        id_cond: Optional[torch.Tensor] = None,
        id_vit_hidden: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Only shard the text embeddings across SP ranks when their length is
        # divisible by the SP world size.
        if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() != 0:
            get_runtime_state().split_text_embed_in_sp = False
        else:
            get_runtime_state().split_text_embed_in_sp = True

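        # temporal_size is the latent frame count before sharding; below, the
        # batch is chunked across CFG-parallel ranks and the tokens across
        # sequence-parallel ranks.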
        temporal_size = hidden_states.shape[1]
        if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
            timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        if get_runtime_state().split_text_embed_in_sp:
            encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        if image_rotary_emb is not None:
            freqs_cos, freqs_sin = image_rotary_emb

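            # The rotary tables repeat per frame: reshape to
            # (frames, tokens_per_frame, dim) so every SP rank takes the same
            # spatial slice of each frame, then flatten back.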
            def get_rotary_emb_chunk(freqs):
                dim_thw = freqs.shape[-1]
                freqs = freqs.reshape(temporal_size, -1, dim_thw)
                freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
                freqs = freqs.reshape(-1, dim_thw)
                return freqs

            freqs_cos = get_rotary_emb_chunk(freqs_cos)
            freqs_sin = get_rotary_emb_chunk(freqs_sin)
            image_rotary_emb = (freqs_cos, freqs_sin)

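        # Swap in xFuser's ConsisID attention processor so self-attention runs
        # over the full (sharded) sequence via sequence-parallel communication.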
        for block in transformer.transformer_blocks:
            block.attn1.processor = xFuserConsisIDAttnProcessor2_0()

        output = original_forward(
            hidden_states,
            encoder_hidden_states,
            timestep=timestep,
            timestep_cond=timestep_cond,
            image_rotary_emb=image_rotary_emb,
            attention_kwargs=attention_kwargs,
            id_cond=id_cond,
            id_vit_hidden=id_vit_hidden,
            **kwargs,
        )

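        # Gather the denoised sample back across SP ranks (sequence dim) and
        # CFG ranks (batch dim) before returning.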
        return_dict = not isinstance(output, tuple)
        sample = output[0]
        sample = get_sp_group().all_gather(sample, dim=-2)
        sample = get_cfg_group().all_gather(sample, dim=0)
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward

    original_patch_embed_forward = transformer.patch_embed.forward

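    # The patch embedding concatenates text and image tokens, so gather the
    # full sequence first, run the original embedding, then re-shard the text
    # and image tokens separately across SP ranks.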
    @functools.wraps(transformer.patch_embed.__class__.forward)
    def new_patch_embed(
        self, text_embeds: torch.Tensor, image_embeds: torch.Tensor
    ):
        text_embeds = get_sp_group().all_gather(text_embeds.contiguous(), dim=-2)
        image_embeds = get_sp_group().all_gather(image_embeds.contiguous(), dim=-2)
        batch, num_frames, channels, height, width = image_embeds.shape
        text_len = text_embeds.shape[-2]

        output = original_patch_embed_forward(text_embeds, image_embeds)

        text_embeds = output[:, :text_len, :]
        image_embeds = output[:, text_len:, :].reshape(batch, num_frames, -1, output.shape[-1])

        text_embeds = torch.chunk(text_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        image_embeds = torch.chunk(image_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        image_embeds = image_embeds.reshape(batch, -1, image_embeds.shape[-1])
        return torch.cat([text_embeds, image_embeds], dim=1)

    new_patch_embed = new_patch_embed.__get__(transformer.patch_embed)
    transformer.patch_embed.forward = new_patch_embed

def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
    assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for ConsisID"

    # 1. Prepare all the Checkpoints
    if not os.path.exists(engine_config.model_config.model):
        print("Base Model not found, downloading from Hugging Face...")
        snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=engine_config.model_config.model)
    else:
        print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.")

    # 2. Load Pipeline
    device = torch.device(f"cuda:{local_rank}")
    pipe = ConsisIDPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    elif args.enable_model_cpu_offload:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} model CPU offload enabled")
    else:
        pipe = pipe.to(device)

    face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
        prepare_face_models(engine_config.model_config.model, device=device, dtype=torch.bfloat16)
    )

    if args.enable_tiling:
        pipe.vae.enable_tiling()

    if args.enable_slicing:
        pipe.vae.enable_slicing()

    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

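    # Initialize xDiT's distributed runtime and register the video input shape
    # so every rank agrees on how the sequence is partitioned.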
    initialize_runtime_state(pipe, engine_config)
    get_runtime_state().set_video_input_parameters(
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        batch_size=1,
        num_inference_steps=input_config.num_inference_steps,
        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
    )
    parallelize_transformer(pipe)

    # 3. Prepare Model Input
    id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
        face_helper_1,
        face_clip_model,
        face_helper_2,
        eva_transform_mean,
        eva_transform_std,
        face_main_model,
        device,
        torch.bfloat16,
        input_config.img_file_path,
        is_align_face=True,
    )

    # 4. Generate Identity-Preserving Video
    if engine_config.runtime_config.use_torch_compile:
        torch._inductor.config.reorder_for_compute_comm_overlap = True
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

        # one step to warmup the torch compiler
        output = pipe(
            image=image,
            prompt=input_config.prompt[0],
            id_vit_hidden=id_vit_hidden,
            id_cond=id_cond,
            kps_cond=face_kps,
            height=input_config.height,
            width=input_config.width,
            num_frames=input_config.num_frames,
            num_inference_steps=1,
            generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
            guidance_scale=6.0,
            use_dynamic_cfg=False,
        ).frames[0]

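    # Measure wall time and peak memory for the timed run only; the warmup
    # step above (when compiling) is excluded.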
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        image=image,
        prompt=input_config.prompt[0],
        id_vit_hidden=id_vit_hidden,
        id_cond=id_cond,
        kps_cond=face_kps,
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        num_inference_steps=input_config.num_inference_steps,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
        guidance_scale=6.0,
        use_dynamic_cfg=False,
    ).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if is_dp_last_group():
        resolution = f"{input_config.width}x{input_config.height}"
        output_filename = f"results/consisid_{parallel_info}_{resolution}.mp4"
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB")
    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
    main()
41 changes: 41 additions & 0 deletions examples/run_consisid_usp.sh
@@ -0,0 +1,41 @@
#!/bin/bash
set -x

export PYTHONPATH=$PWD:$PYTHONPATH

# ConsisID configuration
SCRIPT="consisid_usp_example.py"
MODEL_ID="/cfs/dit/ConsisID-preview"
INFERENCE_STEP=50

mkdir -p ./results

# ConsisID specific task args
TASK_ARGS="--height 480 --width 720 --num_frames 49"

# ConsisID parallel configuration
N_GPUS=4
PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 2"
# CFG_ARGS="--use_cfg_parallel"
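# Note: ulysses_degree x ring_degree (x2 when --use_cfg_parallel is enabled)
# should divide N_GPUS; leftover GPUs are used for data parallelism.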

# Uncomment and modify these as needed
# PIPEFUSION_ARGS="--num_pipeline_patch 8"
# OUTPUT_ARGS="--output_type latent"
# PARALLEL_VAE="--use_parallel_vae"
# ENABLE_TILING="--enable_tiling"
# COMPILE_FLAG="--use_torch_compile"

torchrun --master_port=1234 --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
--model $MODEL_ID \
$PARALLEL_ARGS \
$TASK_ARGS \
$PIPEFUSION_ARGS \
$OUTPUT_ARGS \
--num_inference_steps $INFERENCE_STEP \
--warmup_steps 0 \
--prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \
--img_file_path "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \
$CFG_ARGS \
$PARALLEL_VAE \
$ENABLE_TILING \
$COMPILE_FLAG
