refactor: unified interface for long ctx attn (#337)
Eigensystem authored Nov 6, 2024
1 parent eb001a1 commit 2cca70b
Showing 6 changed files with 69 additions and 248 deletions.
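For context, the consolidation visible in the diffs below removes the separate xFuserJointLongContextAttention and xFuserFluxLongContextAttention entry points; joint (e.g. text-conditioning) tokens are now handled by the single xFuserLongContextAttention class, which the updated test exercises with joint_strategy="rear". The sketch below is only an illustration of how the unified layer is driven, mirroring the updated test: the construction arguments are copied from the test, but the forward keyword names (joint_tensor_query/key/value, joint_strategy) and the leading attn argument are assumptions not shown in this diff, and the snippet must run under torchrun with the distributed setup from tests/core/test_xfuser_attn.py.

import torch
import torch.distributed as dist
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

# Assumes init_distributed_environment() / initialize_model_parallel() have
# already been called, as in the test's init_dist() helper (launched via torchrun).
device = torch.device(f"cuda:{dist.get_rank()}")
dtype = torch.float16

# Construction arguments copied from the updated test.
attn_layer = xFuserLongContextAttention(
    scatter_idx=2,
    gather_idx=1,
    ring_impl_type="basic",
    use_kv_cache=False,
).to(device=device, dtype=dtype)

# Local shards of the base and joint sequences: (batch, local_seq_len, heads, head_dim).
batch, local_seq, heads, head_dim = 1, 32, 8, 64
q, k, v = (torch.randn(batch, local_seq, heads, head_dim, device=device, dtype=dtype) for _ in range(3))
jq, jk, jv = (torch.randn(batch, local_seq, heads, head_dim, device=device, dtype=dtype) for _ in range(3))

# Joint tokens are appended behind the base sequence ("rear" strategy) by the
# unified layer; the keyword names below are assumed, not taken from this diff.
out = attn_layer(
    None,  # attn-module slot used by diffusers-style processors; unused in the tests
    q, k, v,
    joint_tensor_query=jq,
    joint_tensor_key=jk,
    joint_tensor_value=jv,
    joint_strategy="rear",
)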
55 changes: 33 additions & 22 deletions tests/core/test_xfuser_attn.py
@@ -1,8 +1,10 @@
import unittest
import torch
import torch.distributed as dist
from xfuser.core.long_ctx_attention.ring.ring_flash_attn import xdit_ring_flash_attn_func
from xfuser.core.long_ctx_attention import xFuserLongContextAttention, xFuserJointLongContextAttention, xFuserFluxLongContextAttention
from xfuser.core.long_ctx_attention.ring.ring_flash_attn import (
xdit_ring_flash_attn_func,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from flash_attn import flash_attn_func
import os

@@ -18,13 +20,15 @@
)


def init_dist(backend='nccl'):
def init_dist(backend="nccl"):
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

print(f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}")

print(
f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}"
)

torch.cuda.set_device(local_rank)
# dist.init_process_group(backend=backend)
init_distributed_environment(rank=rank, world_size=world_size)
@@ -38,11 +42,14 @@ def init_dist(backend='nccl'):
ulysses_degree = 1

initialize_model_parallel(
sequence_parallel_degree=world_size , ring_degree=ring_degree, ulysses_degree=ulysses_degree
sequence_parallel_degree=world_size,
ring_degree=ring_degree,
ulysses_degree=ulysses_degree,
)

return rank, world_size, ring_degree, ulysses_degree


class TestRingFlashAttn(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -52,13 +59,12 @@ def setUpClass(cls):
cls.seq_len = 128
cls.dtype = torch.float16


cls.rank, cls.world_size, cls.ring_degree, cls.ulysses_degree = init_dist()
cls.device = torch.device(f'cuda:{cls.rank}')
cls.device = torch.device(f"cuda:{cls.rank}")

def setUp(self):
torch.manual_seed(42 + self.rank)

@classmethod
def tearDownClass(cls):
dist.destroy_process_group()
@@ -91,34 +97,36 @@ def test_xfuser_attn_layer_joint_strategy_rear(self):
"""Test xFuserLongContextAttention layer in distributed mode"""
# Create test tensors
q, k, v, local_q, local_k, local_v = self._create_test_tensors()
joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = self._create_test_tensors()
joint_q, joint_k, joint_v, local_joint_q, local_joint_k, local_joint_v = (
self._create_test_tensors()
)
joint_strategy = "rear"

attn = None

# Create attention layer
attn_layer = xFuserJointLongContextAttention(
attn_layer = xFuserLongContextAttention(
scatter_idx=2,
gather_idx=1,
ring_impl_type="basic",
use_kv_cache=False,
).to(device=self.device, dtype=self.dtype)

assert attn_layer.ring_pg.size() == self.ring_degree
assert attn_layer.ulysses_pg.size() == self.ulysses_degree

ref_output = flash_attn_func(
torch.cat([q, joint_q], dim=1),
torch.cat([k, joint_k], dim=1),
torch.cat([q, joint_q], dim=1),
torch.cat([k, joint_k], dim=1),
torch.cat([v, joint_v], dim=1),
dropout_p=0.0,
window_size=(-1, -1),
)

# Split ref_output into base and joint parts
base_out = ref_output[:, :self.seq_len, ::] # First half for base attention
joint_out = ref_output[:, self.seq_len:, ::] # Second half for joint attention
base_out = ref_output[:, : self.seq_len, ::] # First half for base attention
joint_out = ref_output[:, self.seq_len :, ::] # Second half for joint attention

# Get local shard for base output
base_out_shard = base_out.chunk(self.world_size, dim=1)[self.rank]
# Duplicate joint output as specified
@@ -153,12 +161,14 @@ def test_xfuser_attn_layer(self):
ring_impl_type="basic",
use_kv_cache=False,
).to(device=self.device, dtype=self.dtype)

assert attn_layer.ring_pg.size() == self.ring_degree
assert attn_layer.ulysses_pg.size() == self.ulysses_degree

ref_output = flash_attn_func(
q, k, v,
q,
k,
v,
dropout_p=0.0,
window_size=(-1, -1),
)
@@ -176,6 +186,7 @@ def test_xfuser_attn_layer(self):
assert torch.max(torch.abs(output - ref_output)) < 1e-3
torch.testing.assert_close(ref_output, output, rtol=1e-3, atol=1e-3)


# torchrun --nproc_per_node=4 -m unittest tests/core/test_xfuser_attn.py
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
8 changes: 1 addition & 7 deletions xfuser/core/__init__.py
@@ -1,15 +1,9 @@
from .cache_manager import CacheManager
from .long_ctx_attention import (
xFuserLongContextAttention,
xFuserJointLongContextAttention,
xFuserFluxLongContextAttention,
)
from .long_ctx_attention import xFuserLongContextAttention
from .utils import gpu_timer_decorator

__all__ = [
"CacheManager",
"xFuserLongContextAttention",
"xFuserJointLongContextAttention",
"xFuserFluxLongContextAttention",
"gpu_timer_decorator",
]
8 changes: 1 addition & 7 deletions xfuser/core/long_ctx_attention/__init__.py
@@ -1,13 +1,7 @@
from .hybrid import (
xFuserLongContextAttention,
xFuserFluxLongContextAttention,
xFuserJointLongContextAttention,
)
from .hybrid import xFuserLongContextAttention
from .ulysses import xFuserUlyssesAttention

__all__ = [
"xFuserLongContextAttention",
"xFuserFluxLongContextAttention",
"xFuserJointLongContextAttention",
"xFuserUlyssesAttention",
]
4 changes: 0 additions & 4 deletions xfuser/core/long_ctx_attention/hybrid/__init__.py
@@ -1,11 +1,7 @@
from .attn_layer import (
xFuserLongContextAttention,
xFuserFluxLongContextAttention,
xFuserJointLongContextAttention,
)

__all__ = [
"xFuserLongContextAttention",
"xFuserFluxLongContextAttention",
"xFuserJointLongContextAttention",
]