
Commit 881a598

danielvegamyhr authored and pytorchmergebot committed
[FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation (pytorch#153357)
Fixes pytorch#147336

## Context
NCU analysis of the fp8 flex attention perf issue in pytorch#147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM. Bringing this to the attention of triton developer @davidberard98, he identified the memory layout of the tensor in HBM as the cause of non-pipelined loads into SRAM, causing the slowdown.

To summarize: in flex attention, when performing the FP8 GEMM `softmax_scores @ V`, the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation. This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](https://github.com/triton-lang/triton/blob/81f93f2c8ec7d20a1f8184def767edeaebeb6812/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp#L403)). That is, when loading 4 bytes of contiguous data from a tensor stored row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new locations in the col-major layout in SRAM. The load is therefore not a candidate for pipelining with cp.async, and instead moves data to registers and then performs a series of single-byte stores.

## Fix summary
- To fix this, we enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, ensuring each tensor exists in HBM in the memory layout needed to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs.

## Benchmarks
Rerunning the repro, fp8 runtime is reduced from 120% of the bf16 runtime to 76% of the bf16 runtime.

Before fix:

```
(flex) [[email protected] ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```

After fix:

```
(flex) [[email protected] ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```

Pull Request resolved: pytorch#153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
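As a rough illustration of the layout requirement described above, here is a minimal sketch of how a caller could arrange Q, K, V strides themselves before an FP8 flex attention call. It is not the repro script from the issue; the shapes and the `to_col_major` helper are made up for this sketch, and it assumes a CUDA device with FP8 support. After this change, `_enforce_mem_layouts` performs the equivalent adjustment inside `flex_attention`.

```python
import torch

# Illustrative shapes only.
B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda").to(torch.float8_e4m3fn)
k = torch.randn(B, H, S, D, device="cuda").to(torch.float8_e4m3fn)
v = torch.randn(B, H, S, D, device="cuda").to(torch.float8_e4m3fn)


def to_col_major(t: torch.Tensor) -> torch.Tensor:
    # Make the last two dims column-major (stride(-2) == 1) without changing
    # the logical shape: transpose, materialize, transpose back.
    return t.transpose(-2, -1).contiguous().transpose(-2, -1)


# Q and K stay row-major (innermost stride == 1); V becomes column-major so the
# `softmax_scores @ V` FP8 GEMM can load >= 4 contiguous bytes per cp.async copy.
q, k = q.contiguous(), k.contiguous()
v = to_col_major(v)

assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-2) == 1
```

The transpose-contiguous-transpose round trip is the same trick the fix uses for V: it materializes the data column-major while keeping the logical (B, H, S, D) shape.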
1 parent eaf2dee commit 881a598

2 files changed: +62 −3 lines changed


test/inductor/test_flex_attention.py

2 additions & 0 deletions

```diff
@@ -3728,6 +3728,8 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_
         l_block_mask_full_q_num_blocks = L_block_mask_full_q_num_blocks
         l_block_mask_full_q_indices = L_block_mask_full_q_indices
 
+        get_device_capability = torch.cuda.get_device_capability('cuda'); get_device_capability = None
+
         score_mod_0 = self.score_mod_0
         mask_fn_0 = self.mask_fn_0
         flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
```
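The extra `get_device_capability` node in the expected graph above comes from the device-capability check that the new `_enforce_mem_layouts` helper (next file) runs while being traced. A standalone sketch of that guard, safe to run even without CUDA:

```python
import torch

# Short-circuits before touching the CUDA runtime when no GPU is present.
is_sm100_or_greater = (
    torch.cuda.is_available()
    and torch.version.cuda is not None
    and torch.cuda.get_device_capability("cuda") >= (10, 0)
)
print(is_sm100_or_greater)
```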

torch/nn/attention/flex_attention.py

60 additions & 3 deletions

```diff
@@ -1154,6 +1154,62 @@ def _validate_nestedness(query: Tensor, key: Tensor, value: Tensor):
         )
 
 
+def _enforce_mem_layouts(
+    query: Tensor, key: Tensor, value: Tensor
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Enforce memory layouts for query, key, and value tensors.
+
+    For non-FP8 dtypes, no action is taken.
+
+    For FP8 dtypes, we enforce the following memory layouts:
+    - Query tensor must be in row-major memory layout, as it will be the left-operand in the FP8 GEMM `q @ k.T`.
+    - Key tensor must be in row-major memory layout, as it will be transposed when used as the right-operand
+      in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM.
+    - Value tensor must be in column-major memory layout, as it will be the right-operand in the FP8 GEMM `softmax_scores @ v`.
+
+    Returns the query, key, and value tensors with the enforced memory layouts.
+    """
+
+    def is_row_major(tensor: Tensor) -> bool:
+        return tensor.stride()[-1] == 1
+
+    def is_col_major(tensor: Tensor) -> bool:
+        return tensor.stride()[-2] == 1
+
+    # These memory layout constraint are only for FP8 GEMMs on architectures prior to SM100.
+    # SM100 has support for TN, NT, TT, NN layouts for FP8 GEMM
+    # (i.e., left and right operands can be in row or column major layouts)
+    # so this check is only needed for older architectures.
+    # See: https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md
+    fp8_dtypes = (
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+    )
+    gemm_precision = query.dtype
+    is_sm100_or_greater = (
+        torch.cuda.is_available()
+        and torch.version.cuda is not None
+        and torch.cuda.get_device_capability("cuda") >= (10, 0)
+    )
+    if gemm_precision not in fp8_dtypes or not is_sm100_or_greater:
+        return query, key, value
+
+    # Query must be in row-major memory layout as the left-operand in the FP8 GEMM `q @ k.T`
+    if not is_row_major(query):
+        query = query.contiguous()
+
+    # Key must be in row-major memory layout as it will be transposed when used as the right-operand
+    # in the FP8 GEMM `q @ k.T`, meaning it will correctly be in column-major memory layout for the GEMM.
+    if not is_row_major(key):
+        key = key.contiguous()
+
+    # Value must be in column-major memory layout as the right-operand in the FP8 GEMM `softmax_scores @ v`
+    if not is_col_major(value):
+        value = value.transpose(-2, -1).contiguous().transpose(-2, -1)
+    return query, key, value
+
+
 def flex_attention(
     query: Tensor,
     key: Tensor,
@@ -1191,9 +1247,9 @@ def score_mod(
         These should have the ``torch.int`` data type and be located on the same device as the score tensor.
 
     Args:
-        query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`.
-        key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`.
-        value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`.
+        query (Tensor): Query tensor; shape :math:`(B, Hq, L, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance.
+        key (Tensor): Key tensor; shape :math:`(B, Hkv, S, E)`. For FP8 dtypes, should be in row-major memory layout for optimal performance.
+        value (Tensor): Value tensor; shape :math:`(B, Hkv, S, Ev)`. For FP8 dtypes, should be in column-major memory layout for optimal performance.
         score_mod (Optional[Callable]): Function to modify attention scores. By default no score_mod is applied.
         block_mask (Optional[BlockMask]): BlockMask object that controls the blocksparsity pattern of the attention.
         scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`.
@@ -1222,6 +1278,7 @@ def score_mod(
     _validate_embed_dim(query, key, value)
     _validate_device(query, key, value)
     _validate_nestedness(query, key, value)
+    query, key, value = _enforce_mem_layouts(query, key, value)
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
    if (not enable_gqa) and query.size(-3) != key.size(-3):
```