apply yunchang 0.5.1 supports ring without flash_attn
feifeibear committed Dec 26, 2024
1 parent 167351a commit 144d4e6
Showing 8 changed files with 43 additions and 203 deletions.
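The core of the change, condensed into a short sketch: the attention processors now select a yunchang attention backend instead of falling back to a separate Ulysses-only wrapper. This is only an illustration assuming yunchang 0.5.1 is installed; the diff itself gates on xFuser's HAS_FLASH_ATTN flag rather than the import probe used here.

```python
# Backend selection mirrored from the attention_processor.py hunks below.
from yunchang.kernels import FlashAttentionImpl

try:
    import flash_attn  # noqa: F401
    attn_type = FlashAttentionImpl.FA      # flash_attn present: keep FlashAttention kernels
except ImportError:
    attn_type = FlashAttentionImpl.TORCH   # no flash_attn: pure-PyTorch attention backend
```

xFuserLongContextAttention forwards this attn_type to yunchang, so the ring/hybrid sequence-parallel path no longer requires flash_attn.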
2 changes: 0 additions & 2 deletions xfuser/core/long_ctx_attention/__init__.py
@@ -1,7 +1,5 @@
 from .hybrid import xFuserLongContextAttention
-from .ulysses import xFuserUlyssesAttention
 
 __all__ = [
     "xFuserLongContextAttention",
-    "xFuserUlyssesAttention",
 ]
3 changes: 3 additions & 0 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -3,6 +3,7 @@
 
 import torch.distributed
 from yunchang import LongContextAttention
+from yunchang.kernels import FlashAttentionImpl
 from yunchang.comm.all_to_all import SeqAllToAll4D
 
 from xfuser.logger import init_logger
@@ -21,6 +22,7 @@ def __init__(
         ring_impl_type: str = "basic",
         use_pack_qkv: bool = False,
         use_kv_cache: bool = False,
+        attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
     ) -> None:
         """
         Arguments:
@@ -35,6 +37,7 @@ def __init__(
             gather_idx=gather_idx,
             ring_impl_type=ring_impl_type,
             use_pack_qkv=use_pack_qkv,
+            attn_type = attn_type,
         )
         self.use_kv_cache = use_kv_cache
         if (
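For reference, a minimal sketch of what the wrapper now forwards to yunchang. This assumes a torch.distributed run with the Ulysses/ring process groups already initialized; the keyword names are the ones visible in the super().__init__ call above, not a full signature.

```python
from yunchang import LongContextAttention
from yunchang.kernels import FlashAttentionImpl

# With yunchang 0.5.1, the hybrid Ulysses + ring attention can run on the
# PyTorch backend, so flash_attn is no longer a hard requirement.
hybrid_attn = LongContextAttention(
    ring_impl_type="basic",
    use_pack_qkv=False,
    attn_type=FlashAttentionImpl.TORCH,
)
```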
5 changes: 0 additions & 5 deletions xfuser/core/long_ctx_attention/ulysses/__init__.py

This file was deleted.

168 changes: 0 additions & 168 deletions xfuser/core/long_ctx_attention/ulysses/attn_layer.py

This file was deleted.

42 changes: 21 additions & 21 deletions xfuser/model_executor/layers/attention_processor.py
@@ -187,7 +187,6 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
@@ -196,9 +195,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import FlashAttentionImpl
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=FlashAttentionImpl.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -395,17 +395,17 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import FlashAttentionImpl
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=FlashAttentionImpl.TORCH,
                 )
 
         if get_fast_attn_enable():
@@ -588,17 +588,17 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import FlashAttentionImpl
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=FlashAttentionImpl.TORCH,
                 )
 
     def __call__(
@@ -789,17 +789,17 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import FlashAttentionImpl
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=FlashAttentionImpl.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -991,17 +991,17 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import FlashAttentionImpl
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=FlashAttentionImpl.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -1168,17 +1168,17 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import FlashAttentionImpl
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=FlashAttentionImpl.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -1340,17 +1340,17 @@ def __init__(self):
         if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             from xfuser.core.long_ctx_attention import (
                 xFuserLongContextAttention,
-                xFuserUlyssesAttention,
             )
 
             if HAS_FLASH_ATTN:
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
-                    use_fa=False,
+                from yunchang.kernels import FlashAttentionImpl
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    attn_type=FlashAttentionImpl.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
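The same else-branch is duplicated in each processor's __init__ above. Purely as an illustration, and not part of this commit, a hypothetical helper could factor out the backend choice; `make_hybrid_attn` and the importlib probe are my own names and choices, and the call assumes xFuser's sequence-parallel groups are already set up.

```python
import importlib.util

from yunchang.kernels import FlashAttentionImpl
from xfuser.core.long_ctx_attention import xFuserLongContextAttention


def make_hybrid_attn(use_kv_cache: bool) -> xFuserLongContextAttention:
    """Hypothetical helper mirroring the repeated if/else pattern above."""
    if importlib.util.find_spec("flash_attn") is not None:
        # flash_attn installed: keep the default FlashAttention backend (FA).
        return xFuserLongContextAttention(use_kv_cache=use_kv_cache)
    # flash_attn missing: fall back to yunchang's torch backend.
    return xFuserLongContextAttention(
        use_kv_cache=use_kv_cache,
        attn_type=FlashAttentionImpl.TORCH,
    )
```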
@@ -5,7 +5,11 @@
 
 from diffusers.models.embeddings import CogVideoXPatchEmbed
 
-from diffusers.models import ConsisIDTransformer3DModel
+try:
+    from diffusers.models import ConsisIDTransformer3DModel
+except ImportError:
+    ConsisIDTransformer3DModel = None
+
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, scale_lora_layers, unscale_lora_layers
 
2 changes: 2 additions & 0 deletions xfuser/model_executor/models/transformers/register.py
@@ -40,6 +40,8 @@ def get_wrapper(cls, transformer: nn.Module) -> xFuserTransformerBaseWrapper:
             origin_transformer_class,
             wrapper_class,
         ) in cls._XFUSER_TRANSFORMER_MAPPING.items():
+            if origin_transformer_class is None:
+                continue
             if isinstance(transformer, origin_transformer_class):
                 if (
                     candidate is None
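Together, the guarded ConsisIDTransformer3DModel import above and this None check let the registry tolerate diffusers versions that do not ship the class. A self-contained sketch of the pattern, with illustrative mapping and function names rather than xFuser's actual ones:

```python
from typing import Optional

# Optional model class: resolves to None on older diffusers releases.
try:
    from diffusers.models import ConsisIDTransformer3DModel
except ImportError:
    ConsisIDTransformer3DModel = None

# Illustrative registry: origin class -> wrapper name.
_TRANSFORMER_MAPPING = {
    ConsisIDTransformer3DModel: "ConsisIDWrapper",
}


def get_wrapper_name(transformer) -> Optional[str]:
    for origin_cls, wrapper_name in _TRANSFORMER_MAPPING.items():
        if origin_cls is None:
            continue  # optional class unavailable in this environment; skip it
        if isinstance(transformer, origin_cls):
            return wrapper_name
    return None
```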