From fe11de03b84bd5aae9e05362019a6ff3c798c4f2 Mon Sep 17 00:00:00 2001
From: Xibo Sun
Date: Mon, 23 Sep 2024 13:45:08 +0800
Subject: [PATCH] fix: use xFuserJointLongContextAttention as the underlying SP method for CogVideoX (#280)

---
 .github/workflows/build.yml                     | 10 ++--
 .../long_ctx_attention/hybrid/attn_layer.py     | 17 ++++++-
 .../layers/attention_processor.py               | 46 +++++++++----------
 3 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 939fa46f..0a220142 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -27,16 +27,16 @@ jobs:
       - run: mkdir ~/xDiT
       - run: unzip ~/xDiT.zip -d ~/xDiT
       - name: Setup docker
-        run: docker run --rm --name xfuser_test_docker -d -i -t --runtime=nvidia --gpus all -v /cfs:/cfs -v /mnt:/mnt -v ~/xDiT:/code xfuser_cicd/test-py_3_11-torch_2_4_1 /bin/bash
+        run: docker run --rm --name xfuser_test_docker_${{github.repository_owner_id}} -d -i -t --runtime=nvidia --gpus all -v /cfs:/cfs -v /mnt:/mnt -v ~/xDiT:/code xfuser_cicd/test-py_3_11-torch_2_4_1 /bin/bash
       - name: Install xfuser
-        run: docker exec -w /code xfuser_test_docker pip3.11 install -e .
+        run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} pip3.11 install -e .
       - name: Test xfuser
-        run: docker exec -w /code xfuser_test_docker sh -c "torchrun --nproc_per_node=8 ./examples/sd3_example.py --model /cfs/dit/stable-diffusion-3-medium-diffusers --pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --num_inference_steps 20 --warmup_steps 0 --prompt 'A small dog' --use_cfg_parallel"
+        run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} sh -c "torchrun --nproc_per_node=8 ./examples/sd3_example.py --model /cfs/dit/stable-diffusion-3-medium-diffusers --pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1 --height 1024 --width 1024 --no_use_resolution_binning --num_inference_steps 20 --warmup_steps 0 --prompt 'A small dog' --use_cfg_parallel"
   clear-env:
     needs: setup-env-and-test
     runs-on: [self-hosted, linux, x64]
     steps:
       - name: Remove Files
-        run: docker exec -w /code xfuser_test_docker sh -c "rm -r *"
+        run: docker exec -w /code xfuser_test_docker_${{github.repository_owner_id}} sh -c "rm -r *"
       - name: Destroy docker
-        run: docker stop xfuser_test_docker
+        run: docker stop xfuser_test_docker_${{github.repository_owner_id}}
diff --git a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
index 60d71264..eb467ba0 100644
--- a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
+++ b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -146,7 +146,22 @@ def forward(
     ):
         # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size)
         # scatter 2, gather 1
-        query = torch.cat([query, joint_tensor_query], dim=1)
+        supported_joint_strategy = ["none", "front", "rear"]
+        if joint_strategy not in supported_joint_strategy:
+            raise ValueError(
+                f"joint_strategy: {joint_strategy} not supported. supported joint strategy: {supported_joint_strategy}"
+            )
+        elif joint_strategy != "none" and joint_tensor_query is None:
+            raise ValueError(
+                "joint_tensor_query must not be None when joint_strategy is not 'none'"
+            )
+        elif joint_strategy == "rear":
+            query = torch.cat([query, joint_tensor_query], dim=1)
+        elif joint_strategy == "front":
+            query = torch.cat([joint_tensor_query, query], dim=1)
+        else:
+            pass
+
         ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
         ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
         attn_heads_per_ulysses_rank = joint_tensor_key.shape[-2] // ulysses_world_size
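For readers tracing the attn_layer.py hunk above, here is a minimal standalone sketch of what the new joint_strategy branch does; the helper name, tensor shapes, and numbers are illustrative and not part of the patch. The joint (text/encoder) query is concatenated onto this rank's latent query along the sequence dimension, either before it ("front") or after it ("rear"); "none" leaves the query untouched, and any other value is rejected.

# Illustrative sketch only: mirrors the joint_strategy handling added to
# xFuserJointLongContextAttention.forward above. Shapes are
# (batch, seq_len, heads, head_dim); attach_joint_query is a hypothetical name.
from typing import Optional
import torch

def attach_joint_query(query: torch.Tensor,
                       joint_tensor_query: Optional[torch.Tensor],
                       joint_strategy: str = "none") -> torch.Tensor:
    supported_joint_strategy = ["none", "front", "rear"]
    if joint_strategy not in supported_joint_strategy:
        raise ValueError(
            f"joint_strategy: {joint_strategy} not supported. "
            f"supported joint strategy: {supported_joint_strategy}"
        )
    if joint_strategy != "none" and joint_tensor_query is None:
        raise ValueError(
            "joint_tensor_query must not be None when joint_strategy is not 'none'"
        )
    if joint_strategy == "rear":
        return torch.cat([query, joint_tensor_query], dim=1)  # text tokens appended
    if joint_strategy == "front":
        return torch.cat([joint_tensor_query, query], dim=1)  # text tokens prepended
    return query  # "none": nothing to attach

# Example: 2 latent tokens on this rank plus 3 joint (text) tokens, 4 heads, dim 8.
q = torch.randn(1, 2, 4, 8)
jq = torch.randn(1, 3, 4, 8)
print(attach_joint_query(q, jq, "front").shape)  # torch.Size([1, 5, 4, 8])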
diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py
index 6a81f390..cb08f953 100644
--- a/xfuser/model_executor/layers/attention_processor.py
+++ b/xfuser/model_executor/layers/attention_processor.py
@@ -121,25 +121,6 @@ def __init__(
         assert (to_k.bias is None) == (to_v.bias is None)
         assert to_k.weight.shape == to_v.weight.shape
 
-        in_size, out_size = to_k.in_features, to_k.out_features
-        to_kv = nn.Linear(
-            in_size,
-            out_size * 2,
-            bias=to_k.bias is not None,
-            device=to_k.weight.device,
-            dtype=to_k.weight.dtype,
-        )
-        to_kv.weight.data[:out_size].copy_(to_k.weight.data)
-        to_kv.weight.data[out_size:].copy_(to_v.weight.data)
-
-        if to_k.bias is not None:
-            assert to_v.bias is not None
-            to_kv.bias.data[:out_size].copy_(to_k.bias.data)
-            to_kv.bias.data[out_size:].copy_(to_v.bias.data)
-
-        self.to_kv = to_kv
-
-
 class xFuserAttentionProcessorRegister:
     _XFUSER_ATTENTION_PROCESSOR_MAPPING = {}
 
@@ -878,7 +859,6 @@ def __call__(
                 encoder_hidden_states
             )
 
-        # kv = attn.to_kv(encoder_hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
@@ -1013,12 +993,12 @@ def __init__(self):
             )
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
-                xFuserLongContextAttention,
+                xFuserJointLongContextAttention,
                 xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
-                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
+                self.hybrid_seq_parallel_attn = xFuserJointLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
@@ -1040,6 +1020,7 @@ def __call__(
         **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        latent_seq_length = hidden_states.size(1)
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
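A rough bookkeeping sketch of why latent_seq_length is recorded in the hunk above; the numbers are illustrative, and it assumes, as the rest of the patch suggests, that sequence parallelism shards only the latent tokens while the text tokens stay whole on every rank.

# Illustrative only (not part of the patch): each rank holds the full text prompt
# tokens but only a 1/N shard of the latent (video) tokens, so both lengths must
# be remembered in order to split the combined sequence back apart later.
import torch

batch, heads, head_dim = 2, 4, 16
text_seq_length = 226          # full prompt length, assumed replicated on every rank
latent_seq_length = 1024       # this rank's shard of the latent tokens

encoder_hidden_states = torch.randn(batch, text_seq_length, heads * head_dim)
hidden_states = torch.randn(batch, latent_seq_length, heads * head_dim)

# Same concatenation the processor performs before projecting to q/k/v.
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
assert hidden_states.size(1) == text_seq_length + latent_seq_length

# After attention, the combined sequence is split back with the saved lengths.
encoder_out, latent_out = hidden_states.split(
    [text_seq_length, latent_seq_length], dim=1
)
print(encoder_out.shape, latent_out.shape)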
@@ -1095,9 +1076,20 @@ def __call__(
 
         #! ---------------------------------------- ATTENTION ----------------------------------------
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+            encoder_query = query[:, :, :text_seq_length, :]
+            query = query[:, :, text_seq_length:, :]
+            encoder_key = key[:, :, :text_seq_length, :]
+            key = key[:, :, text_seq_length:, :]
+            encoder_value = value[:, :, :text_seq_length, :]
+            value = value[:, :, text_seq_length:, :]
+
             query = query.transpose(1, 2)
             key = key.transpose(1, 2)
             value = value.transpose(1, 2)
+            encoder_query = encoder_query.transpose(1, 2)
+            encoder_key = encoder_key.transpose(1, 2)
+            encoder_value = encoder_value.transpose(1, 2)
+
             hidden_states = self.hybrid_seq_parallel_attn(
                 attn,
                 query,
@@ -1105,8 +1097,12 @@ def __call__(
                 value,
                 dropout_p=0.0,
                 causal=False,
-                joint_strategy="none",
+                joint_tensor_query=encoder_query,
+                joint_tensor_key=encoder_key,
+                joint_tensor_value=encoder_value,
+                joint_strategy="front",
             )
+
             hidden_states = hidden_states.reshape(
                 batch_size, -1, attn.heads * head_dim
             )
@@ -1141,12 +1137,14 @@ def __call__(
             # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
         #! ---------------------------------------- ATTENTION ----------------------------------------
+        assert text_seq_length + latent_seq_length == hidden_states.shape[1]
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
+
         encoder_hidden_states, hidden_states = hidden_states.split(
-            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            [text_seq_length, latent_seq_length], dim=1
         )
 
         return hidden_states, encoder_hidden_states
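Putting the attention_processor.py hunks together, a hedged sketch of the call pattern the updated CogVideoX processor appears to follow; the helper name and the stand-in hybrid_attn/attn arguments are assumptions, while the keyword arguments mirror those visible in the diff. The text and latent parts of q/k/v are separated along the sequence axis, the text parts are passed to xFuserJointLongContextAttention as joint tensors with joint_strategy="front", and the fused output is split back into encoder and latent streams using the two saved lengths.

# Sketch only: approximates the sequence-parallel branch of the modified
# CogVideoX processor above. `hybrid_attn` stands in for the
# xFuserJointLongContextAttention instance and `attn` for the diffusers
# Attention module; neither is constructed here.
import torch

def cogvideox_sp_attention(attn, hybrid_attn, query, key, value,
                           text_seq_length, latent_seq_length):
    # q/k/v arrive as (batch, heads, text+latent, head_dim); peel off the text part.
    encoder_query, query = query[:, :, :text_seq_length], query[:, :, text_seq_length:]
    encoder_key, key = key[:, :, :text_seq_length], key[:, :, text_seq_length:]
    encoder_value, value = value[:, :, :text_seq_length], value[:, :, text_seq_length:]

    batch_size, _, _, head_dim = query.shape
    to_seq_first = lambda t: t.transpose(1, 2)  # -> (batch, seq, heads, head_dim)

    out = hybrid_attn(
        attn,
        to_seq_first(query), to_seq_first(key), to_seq_first(value),
        dropout_p=0.0,
        causal=False,
        joint_tensor_query=to_seq_first(encoder_query),
        joint_tensor_key=to_seq_first(encoder_key),
        joint_tensor_value=to_seq_first(encoder_value),
        joint_strategy="front",  # text tokens come first, matching the earlier cat() order
    )
    out = out.reshape(batch_size, -1, attn.heads * head_dim)

    # Split the fused sequence back into (text, latent) with the saved lengths and
    # return them in the processor's (latent, encoder) order.
    encoder_out, latent_out = out.split([text_seq_length, latent_seq_length], dim=1)
    return latent_out, encoder_out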