diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
index e603878a6383..911d2d5dfbf8 100644
--- a/docs/source/en/optimization/attention_backends.md
+++ b/docs/source/en/optimization/attention_backends.md
@@ -99,10 +99,11 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from `kernels` |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
-| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from `kernels` |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
 | `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..4b34f71f16a7 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -83,12 +83,15 @@
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    fa3_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+    fa_interface_hub = _get_fa_from_hub()
+    flash_attn_func_hub = fa_interface_hub.flash_attn_func
 else:
     flash_attn_3_func_hub = None
+    flash_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -173,6 +176,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_HUB = "flash_hub"
+    # FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet.
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
@@ -403,15 +408,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
 
-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support Hub variant of FA and FA3 varlen later
+    elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
             )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend in [
@@ -1228,6 +1233,35 @@ def _flash_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    lse = None
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
 )
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 26d6e3972fb7..21eba3d40065 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -2,22 +2,32 @@
 from .import_utils import is_kernels_available
 
 
-logger = get_logger(__name__)
+if is_kernels_available():
+    from kernels import get_kernel
 
+logger = get_logger(__name__)
 
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}
 
 
-def _get_fa3_from_hub():
+def _get_from_hub(key: str):
     if not is_kernels_available():
         return None
-    else:
-        from kernels import get_kernel
-
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
+
+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise
+
+
+def _get_fa3_from_hub():
+    return _get_from_hub("fa3")
+
+
+def _get_fa_from_hub():
+    return _get_from_hub("fa")
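
For reference, a minimal usage sketch of the new `flash_hub` backend after this patch, assuming the `set_attention_backend` helper documented in `attention_backends.md`; the checkpoint and prompt below are illustrative placeholders, not part of the patch.

```python
# Hedged usage sketch (not from the patch). Assumes diffusers exposes
# `set_attention_backend` on models, as described in
# docs/source/en/optimization/attention_backends.md, and that `kernels` is installed.
import os

# Hub kernel backends are gated behind this env var (see the check in
# `_check_attention_backend_requirements` above); set it before importing diffusers.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint; any pipeline with a transformer backbone works the same way.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route the transformer's attention through the FlashAttention-2 kernel fetched
# from the Hub (`kernels-community/flash-attn`) instead of a local flash-attn build.
pipe.transformer.set_attention_backend("flash_hub")

image = pipe("a photo of a cat").images[0]
```

As the commented-out `FLASH_VARLEN_HUB` enum entry and the updated TODO indicate, only the non-varlen FA2 Hub kernel is wired up here; the variable-length Hub variants are left for a follow-up.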