feat: complete sequence parallelism for CogVideo #285

Merged: 12 commits, Sep 24, 2024
10 changes: 5 additions & 5 deletions .github/workflows/build.yml
@@ -27,16 +27,16 @@ jobs:
- run: mkdir ~/xDiT
- run: unzip ~/xDiT.zip -d ~/xDiT
- name: Setup docker
- run: docker run --rm --name xfuser_test_docker -d -i -t --runtime=nvidia --gpus all -v /cfs:/cfs -v /mnt:/mnt -v ~/xDiT:/code xfuser_cicd/test-py_3_11-torch_2_4_1 /bin/bash
+ run: docker run --rm --name xfuser_test_docker_${{github.repository_owner_id}} -d -i -t --runtime=nvidia --gpus all -v /cfs:/cfs -v /mnt:/mnt -v ~/xDiT:/code xfuser_cicd/test-py_3_11-torch_2_4_1 /bin/bash
- name: Install xfuser
- run: docker exec -w /code xfuser_test_docker pip3.11 install -e .
+ run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} pip3.11 install -e .
- name: Test xfuser
- run: docker exec -w /code xfuser_test_docker sh -c "torchrun --nproc_per_node=8 ./examples/sd3_example.py --model /cfs/dit/stable-diffusion-3-medium-diffusers --pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --num_inference_steps 20 --warmup_steps 0 --prompt 'A small dog' --use_cfg_parallel"
+ run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} sh -c "torchrun --nproc_per_node=8 ./examples/sd3_example.py --model /cfs/dit/stable-diffusion-3-medium-diffusers --pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --num_inference_steps 20 --warmup_steps 0 --prompt 'A small dog' --use_cfg_parallel"
clear-env:
needs: setup-env-and-test
runs-on: [self-hosted, linux, x64]
steps:
- name: Remove Files
- run: docker exec -w /code xfuser_test_docker sh -c "rm -r *"
+ run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} sh -c "rm -r *"
- name: Destroy docker
- run: docker stop xfuser_test_docker
+ run: docker stop xfuser_test_docker_${{github.repository_owner_id}}
2 changes: 1 addition & 1 deletion examples/run.sh
@@ -28,7 +28,7 @@ declare -A MODEL_CONFIGS=(
["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
["Flux"]="flux_example.py /cfs/dit/FLUX.1-schnell 4"
["HunyuanDiT"]="hunyuandit_example.py /mnt/models/SD/HunyuanDiT-v1.2-Diffusers 50"
["CogVideoX"]="cogvideox_example.py /cfs/dit/CogVideoX-2b 1"
["CogVideoX"]="cogvideox_example.py /cfs/dit/CogVideoX-2b 20"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
17 changes: 16 additions & 1 deletion xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -146,7 +146,22 @@ def forward(
):
# 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
- query = torch.cat([query, joint_tensor_query], dim=1)
+ supported_joint_strategy = ["none", "front", "rear"]
+ if joint_strategy not in supported_joint_strategy:
+     raise ValueError(
+         f"joint_strategy: {joint_strategy} not supported. Supported joint strategies: {supported_joint_strategy}"
+     )
+ elif joint_strategy != "none" and joint_tensor_query is None:
+     raise ValueError(
+         'joint_tensor_query must not be None when joint_strategy is not "none"'
+     )
+ elif joint_strategy == "rear":
+     query = torch.cat([query, joint_tensor_query], dim=1)
+ elif joint_strategy == "front":
+     query = torch.cat([joint_tensor_query, query], dim=1)
+ else:
+     pass

ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
attn_heads_per_ulysses_rank = joint_tensor_key.shape[-2] // ulysses_world_size
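The joint_strategy dispatch above decides where the joint (text) query tokens sit relative to the sequence-parallel latent query before attention runs. A minimal standalone sketch of the same logic, with plain tensors in place of the real xFuserLongContextAttention state (the helper name and shapes are illustrative, not part of this PR):

```python
from typing import Optional

import torch

def concat_joint_query(query: torch.Tensor, joint_query: Optional[torch.Tensor],
                       joint_strategy: str) -> torch.Tensor:
    # Tensors are (bs, seq_len, head_cnt, head_size), matching the comment above.
    supported = ["none", "front", "rear"]
    if joint_strategy not in supported:
        raise ValueError(f"joint_strategy: {joint_strategy} not supported. Supported: {supported}")
    if joint_strategy != "none" and joint_query is None:
        raise ValueError('joint_tensor_query must not be None when joint_strategy is not "none"')
    if joint_strategy == "rear":
        return torch.cat([query, joint_query], dim=1)   # joint tokens appended after latents
    if joint_strategy == "front":
        return torch.cat([joint_query, query], dim=1)   # joint tokens prepended before latents
    return query  # "none": leave the query untouched

# usage: 16 latent tokens plus 4 joint (text) tokens, 2 heads of size 8
q, jq = torch.randn(1, 16, 2, 8), torch.randn(1, 4, 2, 8)
assert concat_joint_query(q, jq, "front").shape == (1, 20, 2, 8)
```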
46 changes: 22 additions & 24 deletions xfuser/model_executor/layers/attention_processor.py
@@ -121,25 +121,6 @@ def __init__(
assert (to_k.bias is None) == (to_v.bias is None)
assert to_k.weight.shape == to_v.weight.shape

- in_size, out_size = to_k.in_features, to_k.out_features
- to_kv = nn.Linear(
-     in_size,
-     out_size * 2,
-     bias=to_k.bias is not None,
-     device=to_k.weight.device,
-     dtype=to_k.weight.dtype,
- )
- to_kv.weight.data[:out_size].copy_(to_k.weight.data)
- to_kv.weight.data[out_size:].copy_(to_v.weight.data)
-
- if to_k.bias is not None:
-     assert to_v.bias is not None
-     to_kv.bias.data[:out_size].copy_(to_k.bias.data)
-     to_kv.bias.data[out_size:].copy_(to_v.bias.data)
-
- self.to_kv = to_kv


class xFuserAttentionProcessorRegister:
_XFUSER_ATTENTION_PROCESSOR_MAPPING = {}

@@ -878,7 +859,6 @@ def __call__(
encoder_hidden_states
)

- # kv = attn.to_kv(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

@@ -1013,12 +993,12 @@ def __init__(self):
)
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
from xfuser.core.long_ctx_attention import (
- xFuserLongContextAttention,
+ xFuserJointLongContextAttention,
xFuserUlyssesAttention,
)

if HAS_FLASH_ATTN:
- self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
+ self.hybrid_seq_parallel_attn = xFuserJointLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
@@ -1040,6 +1020,7 @@ def __call__(
**kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
+ latent_seq_length = hidden_states.size(1)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

@@ -1095,18 +1076,33 @@

#! ---------------------------------------- ATTENTION ----------------------------------------
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+ encoder_query = query[:, :, :text_seq_length, :]
+ query = query[:, :, text_seq_length:, :]
+ encoder_key = key[:, :, :text_seq_length, :]
+ key = key[:, :, text_seq_length:, :]
+ encoder_value = value[:, :, :text_seq_length, :]
+ value = value[:, :, text_seq_length:, :]

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
+ encoder_query = encoder_query.transpose(1, 2)
+ encoder_key = encoder_key.transpose(1, 2)
+ encoder_value = encoder_value.transpose(1, 2)

hidden_states = self.hybrid_seq_parallel_attn(
attn,
query,
key,
value,
dropout_p=0.0,
causal=False,
- joint_strategy="none",
+ joint_tensor_query=encoder_query,
+ joint_tensor_key=encoder_key,
+ joint_tensor_value=encoder_value,
+ joint_strategy="front",
)

hidden_states = hidden_states.reshape(
batch_size, -1, attn.heads * head_dim
)
@@ -1141,12 +1137,14 @@ def __call__(
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
#! ---------------------------------------- ATTENTION ----------------------------------------

+ assert text_seq_length + latent_seq_length == hidden_states.shape[1]
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)


encoder_hidden_states, hidden_states = hidden_states.split(
-     [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+     [text_seq_length, latent_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
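CogVideoX carries text and video tokens in a single sequence, so the processor splits them before the hybrid sequence-parallel call, passes the text part as joint tensors, and splits the output again after the projection. A simplified single-GPU sketch of that bookkeeping, with torch.nn.functional.scaled_dot_product_attention standing in for hybrid_seq_parallel_attn (all helper names and shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

def joint_attention(hidden: torch.Tensor, encoder_hidden: torch.Tensor, heads: int):
    """Sketch: text tokens ride along as a joint tensor, then are split back out.

    `hidden` is (bs, latent_len, dim); `encoder_hidden` is (bs, text_len, dim).
    """
    text_len, latent_len = encoder_hidden.size(1), hidden.size(1)
    x = torch.cat([encoder_hidden, hidden], dim=1)        # text first, i.e. "front"
    bs, seq, dim = x.shape
    head_dim = dim // heads
    # (bs, seq, dim) -> (bs, heads, seq, head_dim); q = k = v keeps the sketch short
    qkv = x.view(bs, seq, heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(qkv, qkv, qkv)
    out = out.transpose(1, 2).reshape(bs, seq, dim)
    assert text_len + latent_len == out.shape[1]          # same invariant as above
    encoder_out, latent_out = out.split([text_len, latent_len], dim=1)
    return latent_out, encoder_out

latent, text = torch.randn(1, 16, 32), torch.randn(1, 4, 32)
h, e = joint_attention(latent, text, heads=4)
assert h.shape == (1, 16, 32) and e.shape == (1, 4, 32)
```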
60 changes: 42 additions & 18 deletions xfuser/model_executor/layers/embeddings.py
@@ -115,6 +115,7 @@ def __init__(
super().__init__(
module=patch_embedding,
)
+ self.module: CogVideoXPatchEmbed

def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
@@ -124,6 +125,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
+ sum_height = (
+     get_runtime_state().input_config.height
+     // get_runtime_state().vae_scale_factor_spatial
+ )
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape

@@ -133,28 +138,47 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]

- if get_runtime_state().patch_mode:
-     start, end = get_runtime_state().pp_patches_token_start_end_idx_global[
-         get_runtime_state().pipeline_patch_idx
-     ]
-     image_embeds = image_embeds[
-         :,
-         start:end,
-         :,
-     ]
- else:
-     image_embeds_list = [
-         image_embeds[

+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+     if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != sum_height):
+         raise ValueError(
+             "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
+             "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+         )
+
+     pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+     if (
+         self.sample_height != sum_height
+         or self.sample_width != width
+         or self.sample_frames != pre_time_compression_frames
+     ):
+         pos_embedding = self._get_positional_embeddings(sum_height, width, pre_time_compression_frames)
+         pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
+     else:
+         pos_embedding = self.pos_embedding
+
+     # extract the image part of the positional embedding
+     pos_embedding = pos_embedding[:, self.max_text_seq_length :]
+
+     # slice the positional embedding
+     post_patch_height = sum_height // self.patch_size
+     post_patch_width = width // self.patch_size
+     post_time_compression_frames = (pre_time_compression_frames - 1) // self.temporal_compression_ratio + 1
+
+     pos_embed_list = [
+         pos_embedding[
:,
-             get_runtime_state()
-             .pp_patches_token_start_end_idx_global[i][0] : get_runtime_state()
-             .pp_patches_token_start_end_idx_global[i][1],
+             post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][0]:
+             post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][1],
:,
]
-         for i in range(get_runtime_state().num_pipeline_patch)
+         for i in range(post_time_compression_frames)
]
-     image_embeds = torch.cat(image_embeds_list, dim=1)
+     pos_embedding = torch.cat(pos_embed_list, dim=1)

+     image_embeds = image_embeds + pos_embedding

embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
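Each pipefusion patch owns the same token range inside every latent frame, so the image part of the positional embedding has to be sliced per frame at a fixed offset rather than taken as one contiguous block. A small sketch of that per-frame slicing, assuming a hypothetical helper rather than the PR's actual API:

```python
import torch

def slice_pos_embedding(pos_embedding: torch.Tensor, post_patch_height: int,
                        post_patch_width: int, num_frames: int,
                        start: int, end: int) -> torch.Tensor:
    """Take tokens [start, end) of every frame's (H x W) token grid and
    concatenate them along the sequence dim.

    `pos_embedding` is (1, num_frames * H * W, channels), text part already stripped.
    """
    tokens_per_frame = post_patch_height * post_patch_width
    slices = [
        pos_embedding[:, tokens_per_frame * i + start : tokens_per_frame * i + end, :]
        for i in range(num_frames)
    ]
    return torch.cat(slices, dim=1)

# usage: 2 frames of a 4x4 patch grid; this rank owns tokens 0..8 of each frame
pe = torch.randn(1, 2 * 16, 64)
out = slice_pos_embedding(pe, 4, 4, num_frames=2, start=0, end=8)
assert out.shape == (1, 16, 64)
```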
@@ -41,7 +41,7 @@
):
super().__init__(
transformer=transformer,
- submodule_classes_to_wrap=[nn.Conv2d],
+ submodule_classes_to_wrap=[nn.Conv2d, CogVideoXPatchEmbed],
submodule_name_to_wrap=["attn1"]
)

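Adding CogVideoXPatchEmbed to submodule_classes_to_wrap makes the base wrapper replace that submodule by type as well as by attribute name. A toy sketch of class-based submodule wrapping; Wrapped and wrap_by_class are made-up stand-ins for the xFuser wrapper machinery:

```python
import torch.nn as nn

class Wrapped(nn.Module):
    """Stand-in for an xFuser parallel wrapper around a submodule."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def wrap_by_class(model: nn.Module, classes_to_wrap: tuple) -> None:
    # Replace every descendant whose type matches `classes_to_wrap` with Wrapped(child).
    for name, child in model.named_children():
        if isinstance(child, classes_to_wrap):
            setattr(model, name, Wrapped(child))
        else:
            wrap_by_class(child, classes_to_wrap)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
wrap_by_class(model, (nn.Conv2d,))
assert isinstance(model[0], Wrapped)
```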
35 changes: 14 additions & 21 deletions xfuser/model_executor/pipelines/pipeline_cogvideox.py
@@ -226,8 +226,7 @@ def __call__(
max_sequence_length=max_sequence_length,
device=device,
)
- if do_classifier_free_guidance:
-     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_embeds = self._process_cfg_split_batch_latte(prompt_embeds, negative_prompt_embeds)

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
@@ -272,9 +271,11 @@
if self.interrupt:
continue

- latent_model_input = (
-     torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- )
+ if do_classifier_free_guidance:
+     latent_model_input = torch.cat(
+         [latents] * (2 // get_classifier_free_guidance_world_size())
+     )

latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
@@ -295,21 +296,15 @@
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
-     (
-         1
-         - math.cos(
-             math.pi
-             * (
-                 (num_inference_steps - t.item())
-                 / num_inference_steps
-             )
-             ** 5.0
-         )
-     )
-     / 2
+     (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ if get_classifier_free_guidance_world_size() == 1:
+     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ elif get_classifier_free_guidance_world_size() == 2:
+     noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
+         noise_pred, separate_tensors=True
+     )
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
@@ -344,9 +339,7 @@
"negative_prompt_embeds", negative_prompt_embeds
)

- if i == len(timesteps) - 1 or (
-     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
- ):
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

if get_sequence_parallel_world_size() > 1:
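With --use_cfg_parallel (cfg degree 2), each rank computes only one guidance branch and the two predictions are exchanged via all_gather; at degree 1 a single rank still holds both branches stacked on the batch dimension. A minimal sketch of the combine step, with a plain callable standing in for get_cfg_group().all_gather:

```python
import torch

def combine_cfg(noise_pred: torch.Tensor, guidance_scale: float,
                cfg_world_size: int, cfg_all_gather=None) -> torch.Tensor:
    """Sketch of the guidance combine above.

    cfg degree 1: one rank holds both branches stacked on the batch dim.
    cfg degree 2: each rank holds one branch; `cfg_all_gather` (a stand-in for
    get_cfg_group().all_gather with separate_tensors=True) returns (uncond, text).
    """
    if cfg_world_size == 1:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    elif cfg_world_size == 2:
        noise_pred_uncond, noise_pred_text = cfg_all_gather(noise_pred)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# single-process usage (cfg degree 1): the batch dim carries [uncond, text]
pred = torch.randn(2, 4, 8)
out = combine_cfg(pred, guidance_scale=6.0, cfg_world_size=1)
assert out.shape == (1, 4, 8)
```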