
Commit

fix: use xFuserJointLongContextAttention as the underlying SP method for CogVideoX (#280)

xibosun authored Sep 23, 2024
1 parent 40fade6 commit fe11de0
Showing 3 changed files with 43 additions and 30 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/build.yml
@@ -27,16 +27,16 @@ jobs:
       - run: mkdir ~/xDiT
       - run: unzip ~/xDiT.zip -d ~/xDiT
       - name: Setup docker
-        run: docker run --rm --name xfuser_test_docker -d -i -t --runtime=nvidia --gpus all -v /cfs:/cfs -v /mnt:/mnt -v ~/xDiT:/code xfuser_cicd/test-py_3_11-torch_2_4_1 /bin/bash
+        run: docker run --rm --name xfuser_test_docker_${{github.repository_owner_id}} -d -i -t --runtime=nvidia --gpus all -v /cfs:/cfs -v /mnt:/mnt -v ~/xDiT:/code xfuser_cicd/test-py_3_11-torch_2_4_1 /bin/bash
       - name: Install xfuser
-        run: docker exec -w /code xfuser_test_docker pip3.11 install -e .
+        run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} pip3.11 install -e .
       - name: Test xfuser
-        run: docker exec -w /code xfuser_test_docker sh -c "torchrun --nproc_per_node=8 ./examples/sd3_example.py --model /cfs/dit/stable-diffusion-3-medium-diffusers --pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --num_inference_steps 20 --warmup_steps 0 --prompt 'A small dog' --use_cfg_parallel"
+        run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} sh -c "torchrun --nproc_per_node=8 ./examples/sd3_example.py --model /cfs/dit/stable-diffusion-3-medium-diffusers --pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --num_inference_steps 20 --warmup_steps 0 --prompt 'A small dog' --use_cfg_parallel"
   clear-env:
     needs: setup-env-and-test
     runs-on: [self-hosted, linux, x64]
     steps:
       - name: Remove Files
-        run: docker exec -w /code xfuser_test_docker sh -c "rm -r *"
+        run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} sh -c "rm -r *"
       - name: Destroy docker
-        run: docker stop xfuser_test_docker
+        run: docker stop xfuser_test_docker_${{github.repository_owner_id}}
17 changes: 16 additions & 1 deletion xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -146,7 +146,22 @@ def forward(
     ):
         # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
         # scatter 2, gather 1
-        query = torch.cat([query, joint_tensor_query], dim=1)
+        supported_joint_strategy = ["none", "front", "rear"]
+        if joint_strategy not in supported_joint_strategy:
+            raise ValueError(
+                f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
+            )
+        elif joint_strategy != "none" and joint_tensor_query is None:
+            raise ValueError(
+                f"joint_tensor_query must not be None when joint_strategy is not None"
+            )
+        elif joint_strategy == "rear":
+            query = torch.cat([query, joint_tensor_query], dim=1)
+        elif joint_strategy == "front":
+            query = torch.cat([joint_tensor_query, query], dim=1)
+        else:
+            pass
+
         ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
         ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
         attn_heads_per_ulysses_rank = joint_tensor_key.shape[-2] // ulysses_world_size
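For orientation, the new joint_strategy branch above decides where the joint (text/encoder) query is placed relative to the sharded latent query before the Ulysses/ring attention runs. A minimal single-process sketch of that dispatch; the helper name place_joint_query and the toy shapes are invented for illustration and are not part of the xDiT API:

```python
import torch

def place_joint_query(query, joint_tensor_query, joint_strategy):
    # Mirrors the validation + concat order added in forward(): tensors are
    # laid out as (bs, seq_len, head_cnt, head_size), so dim=1 is the sequence.
    supported = ["none", "front", "rear"]
    if joint_strategy not in supported:
        raise ValueError(f"joint_strategy {joint_strategy!r} not in {supported}")
    if joint_strategy != "none" and joint_tensor_query is None:
        raise ValueError("joint_tensor_query is required unless joint_strategy is 'none'")
    if joint_strategy == "rear":
        return torch.cat([query, joint_tensor_query], dim=1)   # latent tokens first
    if joint_strategy == "front":
        return torch.cat([joint_tensor_query, query], dim=1)   # text tokens first
    return query                                               # "none": unchanged

latent_q = torch.randn(1, 16, 8, 64)   # (bs, latent_seq, heads, head_dim)
text_q = torch.randn(1, 4, 8, 64)      # (bs, text_seq, heads, head_dim)
assert place_joint_query(latent_q, text_q, "front").shape[1] == 20
```

CogVideoX's processor (below) calls the joint attention with joint_strategy="front", matching the text-first layout it builds with torch.cat([encoder_hidden_states, hidden_states], dim=1).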
46 changes: 22 additions & 24 deletions xfuser/model_executor/layers/attention_processor.py
@@ -121,25 +121,6 @@ def __init__(
         assert (to_k.bias is None) == (to_v.bias is None)
         assert to_k.weight.shape == to_v.weight.shape
 
-        in_size, out_size = to_k.in_features, to_k.out_features
-        to_kv = nn.Linear(
-            in_size,
-            out_size * 2,
-            bias=to_k.bias is not None,
-            device=to_k.weight.device,
-            dtype=to_k.weight.dtype,
-        )
-        to_kv.weight.data[:out_size].copy_(to_k.weight.data)
-        to_kv.weight.data[out_size:].copy_(to_v.weight.data)
-
-        if to_k.bias is not None:
-            assert to_v.bias is not None
-            to_kv.bias.data[:out_size].copy_(to_k.bias.data)
-            to_kv.bias.data[out_size:].copy_(to_v.bias.data)
-
-        self.to_kv = to_kv
-
-
 class xFuserAttentionProcessorRegister:
     _XFUSER_ATTENTION_PROCESSOR_MAPPING = {}

@@ -878,7 +859,6 @@ def __call__(
                 encoder_hidden_states
             )
 
-        # kv = attn.to_kv(encoder_hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
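The deletions in this file retire the fused KV projection: __init__ no longer stacks to_k and to_v into a single to_kv linear, and the leftover commented-out to_kv call in __call__ goes with it, so keys and values always come from the separate attn.to_k / attn.to_v projections. A small sketch of the equivalence the fused path relied on; bias-free layers are assumed for brevity and this is illustration only, not xDiT code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 32
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)

# Fused projection: stack the two weight matrices along the output dimension.
to_kv = nn.Linear(dim, dim * 2, bias=False)
to_kv.weight.data[:dim].copy_(to_k.weight.data)
to_kv.weight.data[dim:].copy_(to_v.weight.data)

x = torch.randn(2, 7, dim)                        # (bs, seq, dim)
key_fused, value_fused = to_kv(x).chunk(2, dim=-1)
assert torch.allclose(key_fused, to_k(x), atol=1e-5)
assert torch.allclose(value_fused, to_v(x), atol=1e-5)
```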

@@ -1013,12 +993,12 @@ def __init__(self):
             )
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
-                xFuserLongContextAttention,
+                xFuserJointLongContextAttention,
                 xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
-                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
+                self.hybrid_seq_parallel_attn = xFuserJointLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
@@ -1040,6 +1020,7 @@ def __call__(
         **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        latent_seq_length = hidden_states.size(1)
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
@@ -1095,18 +1076,33 @@ def __call__(
 
         #! ---------------------------------------- ATTENTION ----------------------------------------
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+            encoder_query = query[:, :, :text_seq_length, :]
+            query = query[:, :, text_seq_length:, :]
+            encoder_key = key[:, :, :text_seq_length, :]
+            key = key[:, :, text_seq_length:, :]
+            encoder_value = value[:, :, :text_seq_length, :]
+            value = value[:, :, text_seq_length:, :]
+
             query = query.transpose(1, 2)
             key = key.transpose(1, 2)
             value = value.transpose(1, 2)
+            encoder_query = encoder_query.transpose(1, 2)
+            encoder_key = encoder_key.transpose(1, 2)
+            encoder_value = encoder_value.transpose(1, 2)
+
             hidden_states = self.hybrid_seq_parallel_attn(
                 attn,
                 query,
                 key,
                 value,
                 dropout_p=0.0,
                 causal=False,
-                joint_strategy="none",
+                joint_tensor_query=encoder_query,
+                joint_tensor_key=encoder_key,
+                joint_tensor_value=encoder_value,
+                joint_strategy="front",
             )
+
             hidden_states = hidden_states.reshape(
                 batch_size, -1, attn.heads * head_dim
             )
@@ -1141,12 +1137,14 @@ def __call__(
             # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         #! ---------------------------------------- ATTENTION ----------------------------------------
 
+        assert text_seq_length + latent_seq_length == hidden_states.shape[1]
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
        hidden_states = attn.to_out[1](hidden_states)
 
+
         encoder_hidden_states, hidden_states = hidden_states.split(
-            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            [text_seq_length, latent_seq_length], dim=1
         )
         return hidden_states, encoder_hidden_states
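Read end to end, the processor changes give the new data flow: query/key/value are split into text and latent parts by text_seq_length, the text parts ride along as joint tensors with joint_strategy="front", and the output is split back using the text and latent lengths recorded up front. A non-parallel sketch of that flow, with F.scaled_dot_product_attention standing in for the distributed xFuserJointLongContextAttention; the project helper and the shapes are invented for illustration:

```python
import torch
import torch.nn.functional as F

bs, heads, head_dim = 1, 8, 64
text_seq_length, latent_seq_length = 4, 16
total = text_seq_length + latent_seq_length

def project(seq_len):
    # Toy stand-in for the attn.to_q / to_k / to_v outputs, already shaped
    # as (bs, heads, seq, head_dim) as in the processor before the split.
    return torch.randn(bs, heads, seq_len, head_dim)

query, key, value = project(total), project(total), project(total)

# Split off the text tokens, as the new processor code does.
encoder_query, query = query[:, :, :text_seq_length], query[:, :, text_seq_length:]
encoder_key, key = key[:, :, :text_seq_length], key[:, :, text_seq_length:]
encoder_value, value = value[:, :, :text_seq_length], value[:, :, text_seq_length:]

# joint_strategy="front": text tokens sit in front of the latent tokens.
q = torch.cat([encoder_query, query], dim=2)
k = torch.cat([encoder_key, key], dim=2)
v = torch.cat([encoder_value, value], dim=2)

hidden_states = F.scaled_dot_product_attention(q, k, v)        # (bs, heads, seq, head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(bs, -1, heads * head_dim)
assert text_seq_length + latent_seq_length == hidden_states.shape[1]

encoder_hidden_states, hidden_states = hidden_states.split(
    [text_seq_length, latent_seq_length], dim=1
)
```

The attn.to_out projections and the sequence-parallel all-to-all communication are omitted here; only the split/concat/split bookkeeping is shown.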
