remove dependency on flash_attn
feifeibear committed Dec 26, 2024
1 parent 144d4e6 commit bffd020
Showing 7 changed files with 41 additions and 28 deletions.
4 changes: 2 additions & 2 deletions examples/run.sh
@@ -29,8 +29,8 @@ TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"
+N_GPUS=4
+PARALLEL_ARGS="--pipefusion_parallel_degree 1 --ulysses_degree 2 --ring_degree 2"

# CFG_ARGS="--use_cfg_parallel"

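For context on the N_GPUS change above: the parallel degrees are generally expected to multiply to the launched world size, so dropping the pipefusion degree from 2 to 1 (with ulysses=2 and ring=2) halves the GPU count from 8 to 4. A minimal sanity-check sketch (illustrative Python, not part of the diff; variable names are hypothetical):

pipefusion_degree = 1
ulysses_degree = 2
ring_degree = 2
cfg_degree = 1  # would be 2 if --use_cfg_parallel were enabled (it is commented out above)

# World size the example launch expects.
n_gpus = pipefusion_degree * ulysses_degree * ring_degree * cfg_degree
assert n_gpus == 4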
6 changes: 3 additions & 3 deletions setup.py
@@ -32,19 +32,19 @@ def get_cuda_version():
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
"distvae",
"yunchang>=0.3.0",
"yunchang>=0.6.0",
"pytest",
"flask",
"opencv-python",
"imageio",
"imageio-ffmpeg",
"optimum-quanto",
"flash_attn>=2.6.3",
"ray"
],
extras_require={
"diffusers": [
"diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"diffusers>=0.32.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"flash_attn>=2.6.3",
]
},
url="https://github.com/xdit-project/xDiT.",
4 changes: 0 additions & 4 deletions xfuser/config/config.py
@@ -130,10 +130,6 @@ def __post_init__(self):
f"sp_degree is {self.sp_degree}, please set it "
f"to 1 or install 'yunchang' to use it"
)
-if not HAS_FLASH_ATTN and self.ring_degree > 1:
-    raise ValueError(
-        f"Flash attention not found. Ring attention not available. Please set ring_degree to 1"
-    )


@dataclass
8 changes: 7 additions & 1 deletion xfuser/core/fast_attention/attn_layer.py
@@ -7,7 +7,12 @@
from diffusers.models.attention_processor import Attention
from typing import Optional
import torch.nn.functional as F
-import flash_attn

+try:
+    import flash_attn
+except ImportError:
+    flash_attn = None

from enum import Flag, auto
from .fast_attn_state import get_fast_attn_window_size

@@ -165,6 +170,7 @@ def __call__(
is_causal=False,
).transpose(1, 2)
elif method.has(FastAttnMethod.FULL_ATTN):
+assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
all_hidden_states = flash_attn.flash_attn_func(query, key, value)
if need_compute_residual:
# Compute the full-window attention residual
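The hunk above replaces the unconditional import with a guarded one and defers the failure to the point where FlashAttention is actually requested. A standalone sketch of that pattern (illustrative only; it mirrors the diff rather than adding new behavior):

try:
    import flash_attn  # optional dependency after this commit
except ImportError:
    flash_attn = None

def full_attn(query, key, value):
    # Fail with an actionable message only when the flash_attn path is used.
    assert flash_attn is not None, "FlashAttention is not available, please install flash_attn"
    return flash_attn.flash_attn_func(query, key, value)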
8 changes: 6 additions & 2 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -3,7 +3,11 @@

import torch.distributed
from yunchang import LongContextAttention
-from yunchang.kernels import FlashAttentionImpl
+try:
+    from yunchang.kernels import AttnType
+except ImportError:
+    raise ImportError("Please install yunchang 0.6.0 or later")

from yunchang.comm.all_to_all import SeqAllToAll4D

from xfuser.logger import init_logger
@@ -22,7 +26,7 @@ def __init__(
ring_impl_type: str = "basic",
use_pack_qkv: bool = False,
use_kv_cache: bool = False,
-attn_type: FlashAttentionImpl = FlashAttentionImpl.FA,
+attn_type: AttnType = AttnType.FA,
) -> None:
"""
Arguments:
10 changes: 8 additions & 2 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -1,11 +1,16 @@
import torch
-import flash_attn
-from flash_attn.flash_attn_interface import _flash_attn_forward

from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from yunchang.ring.utils import RingComm, update_out_and_lse
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc

+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import _flash_attn_forward
+except ImportError:
+    flash_attn = None
+    _flash_attn_forward = None

def xdit_ring_flash_attn_forward(
process_group,
@@ -80,6 +85,7 @@ def xdit_ring_flash_attn_forward(
key, value = k, v

if not causal or step <= comm.rank:
+assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
if flash_attn.__version__ <= "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
q,
29 changes: 15 additions & 14 deletions xfuser/model_executor/layers/attention_processor.py
@@ -195,10 +195,11 @@ def __init__(self):
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
+assert yunchang.__version__ >= "0.6.0"
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache,
-attn_type=FlashAttentionImpl.TORCH,
+attn_type=AttnType.TORCH,
)
else:
self.hybrid_seq_parallel_attn = None
@@ -402,10 +403,10 @@ def __init__(self):
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache,
-attn_type=FlashAttentionImpl.TORCH,
+attn_type=AttnType.TORCH,
)

if get_fast_attn_enable():
@@ -595,10 +596,10 @@ def __init__(self):
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache,
-attn_type=FlashAttentionImpl.TORCH,
+attn_type=AttnType.TORCH,
)

def __call__(
@@ -796,10 +797,10 @@ def __init__(self):
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache,
-attn_type=FlashAttentionImpl.TORCH,
+attn_type=AttnType.TORCH,
)
else:
self.hybrid_seq_parallel_attn = None
@@ -998,10 +999,10 @@ def __init__(self):
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache,
-attn_type=FlashAttentionImpl.TORCH,
+attn_type=AttnType.TORCH,
)
else:
self.hybrid_seq_parallel_attn = None
@@ -1175,10 +1176,10 @@ def __init__(self):
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache,
-attn_type=FlashAttentionImpl.TORCH,
+attn_type=AttnType.TORCH,
)
else:
self.hybrid_seq_parallel_attn = None
@@ -1347,10 +1348,10 @@ def __init__(self):
use_kv_cache=self.use_long_ctx_attn_kvcache
)
else:
-from yunchang.kernels import FlashAttentionImpl
+from yunchang.kernels import AttnType
self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
use_kv_cache=self.use_long_ctx_attn_kvcache,
-attn_type=FlashAttentionImpl.TORCH,
+attn_type=AttnType.TORCH,
)
else:
self.hybrid_seq_parallel_attn = None
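For reference, the repeated change in this file swaps yunchang's FlashAttentionImpl enum for AttnType and asserts yunchang>=0.6.0. A minimal caller-side sketch, assuming AttnType.TORCH selects the pure-PyTorch backend as the diff suggests (the use_kv_cache value here is hypothetical):

import yunchang
from yunchang.kernels import AttnType
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

assert yunchang.__version__ >= "0.6.0"

hybrid_seq_parallel_attn = xFuserLongContextAttention(
    use_kv_cache=False,        # hypothetical value for illustration
    attn_type=AttnType.TORCH,  # torch SDPA backend; flash_attn not required
)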
