feat: add usp implementations
xibosun committed Feb 10, 2025
1 parent b495702 commit 4bbbd1b
Showing 2 changed files with 93 additions and 20 deletions.
79 changes: 68 additions & 11 deletions xfuser/model_executor/layers/usp.py
@@ -1,11 +1,11 @@
# This file implements USP with torch version >= '2.5.0'
import torch
from torch.nn import functional as F
from torch.distributed.tensor.experimental._attention import _templated_ring_attention
aten = torch.ops.aten

import torch.distributed._functional_collectives as ft_c

from torch.distributed.tensor.experimental._attention import _templated_ring_attention

from yunchang.globals import PROCESS_GROUP

from xfuser.core.distributed import (
@@ -14,17 +14,74 @@
    get_ring_parallel_world_size,
)

from xfuser.envs import PACKAGES_CHECKER
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_FLASH_ATTN = env_info["has_flash_attn"]

aten = torch.ops.aten


def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
    out, *_ = _templated_ring_attention(
        PROCESS_GROUP.RING_PG,
        aten._scaled_dot_product_flash_attention,
        query,
        key,
        value,
        dropout_p=dropout_p,
        is_causal=is_causal
    )
    if torch.__version__ >= "2.6.0":
        from torch.distributed.tensor.experimental._attention import _cp_options
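        # torch >= 2.6 adds a positional seq_dim argument to _templated_ring_attention
        # and load-balances (reorders) sequence shards by default; disable that so each
        # rank keeps a contiguous shard.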
        _cp_options.enable_load_balance = False
        kwargs = {
            "dropout_p": dropout_p,
            "is_causal": is_causal,
        }
        if HAS_FLASH_ATTN:
            out, *_ = _templated_ring_attention(
                PROCESS_GROUP.RING_PG,
                1,
                aten._scaled_dot_product_flash_attention,
                query,
                key,
                value,
                **kwargs,
            )
        else:
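            # No flash-attn: fall back to the memory-efficient SDPA kernel, which
            # additionally expects attn_bias and compute_log_sumexp.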
            kwargs = {
                **kwargs,
                "attn_bias": None,
                "compute_log_sumexp": True,
            }
            out, *_ = _templated_ring_attention(
                PROCESS_GROUP.RING_PG,
                1,
                aten._scaled_dot_product_efficient_attention,
                query,
                key,
                value,
                **kwargs,
            )
    else:
        kwargs = {
            "dropout_p": dropout_p,
            "is_causal": is_causal,
        }
        if HAS_FLASH_ATTN:
            out, *_ = _templated_ring_attention(
                PROCESS_GROUP.RING_PG,
                aten._scaled_dot_product_flash_attention,
                query,
                key,
                value,
                **kwargs
            )
        else:
            kwargs = {
                **kwargs,
                "attn_bias": None,
                "compute_log_sumexp": True,
            }
            out, *_ = _templated_ring_attention(
                PROCESS_GROUP.RING_PG,
                aten._scaled_dot_product_efficient_attention,
                query,
                key,
                value,
                **kwargs,
            )
    return out


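For reference, a minimal usage sketch of the new ring_attn (not part of this commit): it assumes xfuser has already initialized torch.distributed and yunchang's PROCESS_GROUP, and that each rank holds its local sequence shard of query/key/value in the [batch, heads, seq, head_dim] layout used by the aten SDPA ops; the shapes below are illustrative only.

    # Illustrative sketch: assumes the distributed environment and the ring
    # process group (yunchang's PROCESS_GROUP) were set up by xfuser elsewhere.
    import torch
    from xfuser.model_executor.layers.usp import ring_attn

    batch, heads, head_dim = 2, 16, 64
    local_seq = 128  # this rank's shard of the full sequence
    device = torch.device("cuda", torch.cuda.current_device())

    q = torch.randn(batch, heads, local_seq, head_dim, device=device, dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # K/V blocks are rotated around the ring process group; the result is the
    # attention output for this rank's local queries, same shape as q.
    out = ring_attn(q, k, v, dropout_p=0.0, is_causal=False)
    assert out.shape == q.shape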
34 changes: 25 additions & 9 deletions xfuser/model_executor/layers/usp_legacy.py
@@ -6,27 +6,43 @@

from yunchang.globals import PROCESS_GROUP
from yunchang.ring.ring_flash_attn import ring_flash_attn_forward
from yunchang.ring.ring_pytorch_attn import ring_pytorch_attn_func

from xfuser.core.distributed import (
    get_sequence_parallel_world_size,
    get_ulysses_parallel_world_size,
    get_ring_parallel_world_size,
)

from xfuser.envs import PACKAGES_CHECKER
env_info = PACKAGES_CHECKER.get_packages_info()
HAS_FLASH_ATTN = env_info["has_flash_attn"]


def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
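    # Incoming q/k/v are [batch, heads, seq, head_dim]; the ring kernels used
    # below expect [batch, seq, heads, head_dim], hence the transposes.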
    query = query.transpose(1,2).contiguous()
    key = key.transpose(1,2).contiguous()
    value = value.transpose(1,2).contiguous()
    out, *_ = ring_flash_attn_forward(
        PROCESS_GROUP.RING_PG,
        query,
        key,
        value,
        softmax_scale=query.shape[-1] ** (-0.5),
        dropout_p=dropout_p,
        causal=is_causal,
    )
    if HAS_FLASH_ATTN:
        out, *_ = ring_flash_attn_forward(
            PROCESS_GROUP.RING_PG,
            query,
            key,
            value,
            softmax_scale=query.shape[-1] ** (-0.5),
            dropout_p=dropout_p,
            causal=is_causal,
        )
    else:
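        # flash-attn is not installed: use yunchang's pure-PyTorch ring attention.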
        out = ring_pytorch_attn_func(
            query,
            key,
            value,
            dropout_p=dropout_p,
            softmax_scale=query.shape[-1] ** (-0.5),
            causal=is_causal,
            group=PROCESS_GROUP.RING_PG,
        )
    out = out.transpose(1,2).contiguous()
    return out

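A hedged sanity-check sketch for the legacy path (also not part of the commit): with a ring world size of 1, ring_attn should reduce to plain scaled-dot-product attention, so its [batch, heads, seq, head_dim] contract can be compared against F.scaled_dot_product_attention; this assumes a single-rank setup with PROCESS_GROUP already initialized.

    import torch
    import torch.nn.functional as F
    from xfuser.model_executor.layers.usp_legacy import ring_attn

    q = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    out_ring = ring_attn(q, k, v, dropout_p=0.0, is_causal=False)
    out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

    # fp16 attention kernels differ slightly, so compare with a loose tolerance.
    torch.testing.assert_close(out_ring, out_ref, rtol=2e-2, atol=2e-2)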
