hotfix pipefusion using flash_attn (#411)
feifeibear authored Dec 26, 2024
1 parent 92187b8 commit 81700db
Showing 2 changed files with 31 additions and 21 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,7 @@ def get_cuda_version():
],
extras_require={
"diffusers": [
"diffusers>=0.32.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"diffusers>=0.31.0", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"flash_attn>=2.6.3",
]
},
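
Note: in the hunk above, flash_attn is listed only under the optional "diffusers" extra, so an installation made without that extra may not have it at all. The runtime change in the next file guards for exactly that case with the usual optional-import pattern, sketched here for context (HAS_FLASH_ATTN is an illustrative name, not part of xFuser):

    # Illustrative sketch of the optional-import guard that pairs with an
    # extras_require dependency: if flash_attn is not installed, the module
    # still imports, and callers can branch on the flag instead of crashing.
    try:
        import flash_attn
        HAS_FLASH_ATTN = True
    except ImportError:
        flash_attn = None
        HAS_FLASH_ATTN = False
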
xfuser/core/long_ctx_attention/ring/ring_flash_attn.py (50 changes: 30 additions & 20 deletions)
@@ -11,6 +11,7 @@
except ImportError:
flash_attn = None
_flash_attn_forward = None
from yunchang.kernels.attention import pytorch_attn_forward

def xdit_ring_flash_attn_forward(
process_group,
@@ -85,34 +86,43 @@ def xdit_ring_flash_attn_forward(
key, value = k, v

if not causal or step <= comm.rank:
assert flash_attn is not None, f"FlashAttention is not available, please install flash_attn"
if flash_attn.__version__ <= "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
if flash_attn is None:
block_out, block_lse = pytorch_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size=window_size,
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
else:
block_out, block_lse, _, _ = _flash_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
if flash_attn.__version__ <= "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size=window_size,
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
else:
block_out, block_lse, _, _ = _flash_attn_forward(
q,
key,
value,
dropout_p,
softmax_scale,
causal=causal and step == 0,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=0.0,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)

if step + 1 != comm.world_size:
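
For context on the fallback introduced in this hunk: ring attention only needs each block's attention output together with its log-sum-exp (LSE) so that update_out_and_lse can merge the per-step partial results, so any kernel that returns that pair can stand in for _flash_attn_forward. Below is a minimal pure-PyTorch sketch of such a kernel, assuming flash_attn's (batch, seqlen, heads, head_dim) layout; it is an illustration, not yunchang's pytorch_attn_forward:

    import torch

    def naive_attn_forward(q, k, v, softmax_scale=None, causal=False):
        # Reference attention that, like the flash kernels, returns (out, lse).
        # q, k, v: (batch, seqlen, heads, head_dim), as used by flash_attn.
        scale = softmax_scale if softmax_scale is not None else q.shape[-1] ** -0.5
        # Move heads next to batch for the matmuls: (batch, heads, seqlen, head_dim).
        q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))
        scores = torch.matmul(q_, k_.transpose(-2, -1)) * scale  # (b, h, sq, sk)
        if causal:
            sq, sk = scores.shape[-2], scores.shape[-1]
            mask = torch.triu(torch.ones(sq, sk, device=q.device), diagonal=sk - sq + 1).bool()
            scores = scores.masked_fill(mask, float("-inf"))
        lse = torch.logsumexp(scores, dim=-1)                   # (b, h, sq), consumed by update_out_and_lse
        out = torch.matmul(torch.softmax(scores, dim=-1), v_)   # (b, h, sq, d)
        return out.transpose(1, 2), lse                         # out back to (b, sq, h, d)

The causal flag in this sketch plays the same role as causal=causal and step == 0 in the loop above: only the local (step 0) key/value block is masked causally, while blocks received from earlier ranks attend without a mask.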