diff --git a/examples/run.sh b/examples/run.sh
index 12897814..1b5a490b 100644
--- a/examples/run.sh
+++ b/examples/run.sh
@@ -29,8 +29,8 @@ TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
 
 # On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"
+N_GPUS=4
+PARALLEL_ARGS="--pipefusion_parallel_degree 1 --ulysses_degree 2 --ring_degree 2"
 
 # CFG_ARGS="--use_cfg_parallel"
 
 
diff --git a/setup.py b/setup.py
index 87a5d7fe..c933a467 100644
--- a/setup.py
+++ b/setup.py
@@ -32,19 +32,19 @@ def get_cuda_version():
             "sentencepiece>=0.1.99",
             "beautifulsoup4>=4.12.3",
             "distvae",
-            "yunchang>=0.3.0",
+            "yunchang>=0.6.0",
             "pytest",
             "flask",
             "opencv-python",
             "imageio",
             "imageio-ffmpeg",
             "optimum-quanto",
-            "flash_attn>=2.6.3",
             "ray"
         ],
         extras_require={
             "diffusers": [
-                "diffusers>=0.31.0",  # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
+                "diffusers>=0.32.0",  # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
+                "flash_attn>=2.6.3",
             ]
         },
         url="https://github.com/xdit-project/xDiT.",
diff --git a/xfuser/config/config.py b/xfuser/config/config.py
index 801949ea..ce1c87f8 100644
--- a/xfuser/config/config.py
+++ b/xfuser/config/config.py
@@ -130,10 +130,6 @@ def __post_init__(self):
                     f"sp_degree is {self.sp_degree}, please set it "
                     f"to 1 or install 'yunchang' to use it"
                 )
-            if not HAS_FLASH_ATTN and self.ring_degree > 1:
-                raise ValueError(
-                    f"Flash attention not found. Ring attention not available. Please set ring_degree to 1"
-                )
 
 
 @dataclass
diff --git a/xfuser/core/fast_attention/attn_layer.py b/xfuser/core/fast_attention/attn_layer.py
index e82c4bfa..0d306d40 100644
--- a/xfuser/core/fast_attention/attn_layer.py
+++ b/xfuser/core/fast_attention/attn_layer.py
@@ -7,7 +7,12 @@
 from diffusers.models.attention_processor import Attention
 from typing import Optional
 import torch.nn.functional as F
-import flash_attn
+
+try:
+    import flash_attn
+except ImportError:
+    flash_attn = None
+
 from enum import Flag, auto
 
 from .fast_attn_state import get_fast_attn_window_size
@@ -165,6 +170,7 @@ def __call__(
                 is_causal=False,
             ).transpose(1, 2)
         elif method.has(FastAttnMethod.FULL_ATTN):
+            assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
             all_hidden_states = flash_attn.flash_attn_func(query, key, value)
             if need_compute_residual:
                 # Compute the full-window attention residual
diff --git a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
index 1190e467..a459630a 100644
--- a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
+++ b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -3,7 +3,11 @@
 import torch.distributed
 
 from yunchang import LongContextAttention
-from yunchang.kernels import FlashAttentionImpl
+try:
+    from yunchang.kernels import AttnType
+except ImportError:
+    raise ImportError("Please install yunchang 0.6.0 or later")
+
 from yunchang.comm.all_to_all import SeqAllToAll4D
 
 from xfuser.logger import init_logger
@@ -22,7 +26,7 @@ def __init__(
         ring_impl_type: str = "basic",
         use_pack_qkv: bool = False,
         use_kv_cache: bool = False,
-        attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+        attn_type: AttnType = AttnType.FA,
     ) -> None:
         """
         Arguments:
diff --git a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
index 9e8b116d..a4e8a501 100644
--- a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
+++ b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -1,11 +1,16 @@
 import torch
-import flash_attn
-from flash_attn.flash_attn_interface import _flash_attn_forward
+
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 from xfuser.core.cache_manager.cache_manager import get_cache_manager
 from yunchang.ring.utils import RingComm, update_out_and_lse
 from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
 
+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import _flash_attn_forward
+except ImportError:
+    flash_attn = None
+    _flash_attn_forward = None
 
 def xdit_ring_flash_attn_forward(
     process_group,
@@ -80,6 +85,7 @@ def xdit_ring_flash_attn_forward(
             key, value = k, v
 
         if not causal or step <= comm.rank:
+            assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
             if flash_attn.__version__ <= "2.6.3":
                 block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
                     q,
diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py
index c6223e5e..2deaca63 100644
--- a/xfuser/model_executor/layers/attention_processor.py
+++ b/xfuser/model_executor/layers/attention_processor.py
@@ -195,10 +195,11 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                from yunchang.kernels import FlashAttentionImpl
+                from yunchang.kernels import AttnType
+                assert yunchang.__version__ >= "0.6.0"
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    attn_type=FlashAttentionImpl.TORCH,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -402,10 +403,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                from yunchang.kernels import FlashAttentionImpl
+                from yunchang.kernels import AttnType
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    attn_type=FlashAttentionImpl.TORCH,
+                    attn_type=AttnType.TORCH,
                 )
 
         if get_fast_attn_enable():
@@ -595,10 +596,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                from yunchang.kernels import FlashAttentionImpl
+                from yunchang.kernels import AttnType
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    attn_type=FlashAttentionImpl.TORCH,
+                    attn_type=AttnType.TORCH,
                 )
 
     def __call__(
@@ -796,10 +797,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                from yunchang.kernels import FlashAttentionImpl
+                from yunchang.kernels import AttnType
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    attn_type=FlashAttentionImpl.TORCH,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -998,10 +999,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                from yunchang.kernels import FlashAttentionImpl
+                from yunchang.kernels import AttnType
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    attn_type=FlashAttentionImpl.TORCH,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -1175,10 +1176,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                from yunchang.kernels import FlashAttentionImpl
+                from yunchang.kernels import AttnType
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    attn_type=FlashAttentionImpl.TORCH,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
@@ -1347,10 +1348,10 @@ def __init__(self):
                     use_kv_cache=self.use_long_ctx_attn_kvcache
                 )
             else:
-                from yunchang.kernels import FlashAttentionImpl
+                from yunchang.kernels import AttnType
                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
                     use_kv_cache=self.use_long_ctx_attn_kvcache,
-                    attn_type=FlashAttentionImpl.TORCH,
+                    attn_type=AttnType.TORCH,
                 )
         else:
             self.hybrid_seq_parallel_attn = None
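
Editor's note (not part of the patch): every file above applies the same guarded-import pattern, turning flash_attn into an optional dependency that only fails at the call sites that actually need it. A minimal standalone sketch of that pattern, using a hypothetical helper name run_flash_attention in place of the real xDiT call sites:

import torch

try:
    import flash_attn  # optional dependency; only the FlashAttention code path needs it
except ImportError:
    flash_attn = None  # defer the failure to the point of use

def run_flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Fail only when the FlashAttention path is actually requested,
    # mirroring the asserts this patch adds in attn_layer.py and ring_flash_attn.py.
    assert flash_attn is not None, "FlashAttention is not available, please install flash_attn"
    # flash_attn_func expects (batch, seqlen, nheads, headdim) fp16/bf16 CUDA tensors.
    return flash_attn.flash_attn_func(q, k, v)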