diff --git a/flashinfer/cute_dsl/blockwise_gemm.py b/flashinfer/cute_dsl/blockwise_gemm.py
new file mode 100644
index 0000000000..95d0f35a04
--- /dev/null
+++ b/flashinfer/cute_dsl/blockwise_gemm.py
@@ -0,0 +1,2965 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import functools
+import math
+from typing import Callable, List, Optional, Type, Tuple, Union
+
+import cuda.bindings.driver as cuda
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync, tcgen05
+import cutlass.torch as cutlass_torch
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+import cutlass.utils.blackwell_helpers as sm100_utils
+import torch
+
+from flashinfer.utils import get_compute_capability
+from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr
+
+
+"""
+High-performance persistent blockwise dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell architecture
+using CUTE DSL.
+- Matrix A is MxKxL, L is batch dimension, A can be row-major("K")
+- Matrix B is NxKxL, L is batch dimension, B can be column-major("K")
+- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
+- Scale factor A (SFA) is applied per row of A: one scale per (1, 128) tile along (M, K)
+- Scale factor B (SFB) is applied per block of B: one scale per (128, 128) tile along (N, K)
+- For each K iteration, the kernel computes the partial product C = A * B and then accumulates the scaled result Final += C * SFA * SFB
+
+This GEMM kernel supports the following features:
+    - Utilizes Tensor Memory Access (TMA) for efficient memory operations
+    - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations
+    - Implements TMA multicast with cluster to reduce L2 memory traffic
+    - Supports persistent tile scheduling to better overlap memory load/store with MMA between tiles
+    - Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
+
+This GEMM works as follows:
+1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
+2. 
SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using LDGSTS operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final + - Type convert Final matrix to output type. + - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +.. code-block:: bash + + python examples/blackwell/blockwise_gemm/blockwise_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 4096,4096,4096,4 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/blockwise_gemm/blockwise_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 4096,4096,4096,4 + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp8 (e4m3fn) + see detailed valid dtype combinations in below BlockwiseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128/256 +* Mma tiler N must be 128, align with the scaleB requirement +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class BlockwiseGemmKernel: + """This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, e5m2) + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float8E4M3FN + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float16/BFloat16 + - Other data types are not supported for accuracy issues + + :note: Constraints: + - MMA tiler M must be 64/128/256 + - MMA tiler N must be 128 + - Cluster shape M must be multiple of 2 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = BlockwiseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... 
) + >>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell blockwise dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.acc_update_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.epilog_warp_id = ( + 4, + 5, + 6, + 7, + ) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.scale_warp_id = 10 + self.sched_warp_id = 11 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + ) + ) + self.num_regs_uniform_warps = 64 + self.num_regs_sched_warps = 64 + self.num_regs_epilogue_warps = 216 + self.num_regs_acc_update_warps = 216 + + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + self.sched_sync_bar_id = 3 + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + # TMEM offset for final accumulator + self.tmem_final_offset = 384 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + 
mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # TODO: get from args + self.scale_granularity_m = 1 + self.scale_granularity_n = 128 + self.scale_granularity_k = 128 + self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m + self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n + self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k + + if self.scale_k_per_tile != 1: + raise ValueError("scale_k_per_tile must be 1") + if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]: + raise ValueError("scale_m_per_tile must be cta_tile_m") + if self.scale_n_per_tile != 1: + raise ValueError("scale_n_per_tile must be 1") + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_scale_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sfa_dtype, + self.sfb_dtype, + self.scale_m_per_tile * self.scale_k_per_tile, + self.scale_n_per_tile * self.scale_k_per_tile, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.sfa_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_m_per_tile, + ), + ) + self.sfb_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_n, self.scale_n_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_n_per_tile, + ), + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware 
constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param sfa: Scale factor tensor A + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type + self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + tensor_sfa = cute.make_tensor( + sfa.iterator, + cute.make_layout( + ( + (self.scale_granularity_m, sfa.shape[0]), + (self.scale_granularity_k, sfa.shape[1]), + sfa.shape[2], + ), + stride=( + (0, sfa.layout.stride[0]), + (0, sfa.layout.stride[1]), + sfa.layout.stride[2], + ), + ), + ) + tensor_sfb = 
cute.make_tensor( + sfb.iterator, + cute.make_layout( + ( + (self.scale_granularity_n, sfb.shape[0]), + (self.scale_granularity_k, sfb.shape[1]), + sfb.shape[2], + ), + stride=( + (0, sfb.layout.stride[0]), + (0, sfb.layout.stride[1]), + sfb.layout.stride[2], + ), + ), + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + scale_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_scale_stage * 2 + ] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_tile_stage * 2 + ] + epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tensor_sfa, + tensor_sfb, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), # type: ignore[attr-defined] + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU 
device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize mainloop scale_pipeline (barrier) and states + scale_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + self.threads_per_warp * 1, + ) + scale_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + self.threads_per_warp * len(self.epilog_warp_id), + ) + scale_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.scale_mbar_ptr.data_ptr(), + num_stages=self.num_scale_stage, + producer_group=scale_pipeline_producer_group, + consumer_group=scale_pipeline_consumer_group, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize epilogue pipeline (barrier) and states + epi_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.acc_update_warp_id), + self.threads_per_warp * len(self.acc_update_warp_id), + ) + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + self.threads_per_warp * len(self.epilog_warp_id), + ) + epi_pipeline = pipeline.PipelineAsync.create( + 
barrier_storage=storage.epi_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + ) + + # Initialize tile info pipeline (barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_wo_sched, self.threads_wo_sched + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # coordinate + cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl)) + cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl)) + # (bM, bK, loopM, loopK, loopL) + cSFA = cute.local_tile( + cSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, 
loopN, loopK, loopL) + cSFB = cute.local_tile( + cSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # scale viewed as C tensor + sSFA_view_as_C_layout = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + self.cta_tile_shape_mnk[1], + self.num_scale_stage, + ), + stride=((0, 1), 0, self.scale_m_per_tile), + ) + sSFB_view_as_C_layout = cute.make_layout( + ( + self.cta_tile_shape_mnk[0], + (self.scale_granularity_n, self.scale_n_per_tile), + self.num_scale_stage, + ), + stride=(0, (0, 1), self.scale_n_per_tile), + ) + sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout) + sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition global/shared tensor for TMA load A/B + # + # load scaleA/scaleB + atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=mSFA_mkl.element_type.width, + ) + tiled_copy_sfa = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + tiled_copy_sfb = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx) + thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl) + tAsSFA = thr_copy_sfa.partition_D(sSFA) + tAcSFA = thr_copy_sfa.partition_S(cSFA) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl) + tBsSFB = thr_copy_sfb.partition_D(sSFB) + tBcSFB = thr_copy_sfb.partition_S(cSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) 
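+
+        # Overview of the warp specialization implemented below (a summary of the
+        # code in this kernel, for readability):
+        #   - Sched warp (warp 11): runs the persistent tile scheduler and publishes
+        #     (bidx, bidy, bidz, valid) into sInfo through tile_info_pipeline; every
+        #     other warp consumes this to advance to its next output tile.
+        #   - TMA warp (warp 9): loads A/B k-tiles GMEM -> SMEM with TMA as the
+        #     producer of ab_pipeline.
+        #   - Scale warp (warp 10): loads SFA/SFB GMEM -> SMEM with cp.async as the
+        #     producer of scale_pipeline.
+        #   - MMA warp (warp 8): issues tcgen05.mma into a per-k-tile TMEM
+        #     accumulator as the producer of acc_pipeline.
+        #   - Acc update warps (warps 0-3): per k-tile, load the accumulator
+        #     TMEM -> RMEM, apply SFA * SFB, accumulate into the final accumulator,
+        #     then store it back to TMEM at tmem_final_offset and signal epi_pipeline.
+        #   - Epilogue warps (warps 4-7): load the final accumulator TMEM -> RMEM,
+        #     convert to the C dtype, stage through SMEM, and TMA-store to GMEM
+        #     (c_pipeline).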
+ + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + # query next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # acquire tile info pipeline + tile_info_pipeline.producer_acquire(tile_info_producer_state) + + # store the tile info + cur_tile_coord = work_tile.tile_idx + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = cur_tile_coord[2] + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier( + barrier_id=self.sched_sync_bar_id, + number_of_threads=self.threads_per_warp, + ) + # commit tile info pipeline + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_fragment(cute.make_layout((4,)).shape, cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # noqa: B007 + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = 
tBsB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized Scale load warp + # + if warp_idx == self.scale_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + scale_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_fragment(cute.make_layout((4,)).shape, cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # + # Prepare the mask for scaleA/scaleB + # + tApSFA = cute.make_fragment( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tAsSFA, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + tBpSFB = cute.make_fragment( + cute.make_layout( + cute.filter_zeros( + cute.slice_(tBsSFB, (None, None, None, 0)) + ).shape + ), + cutlass.Boolean, + ) + + # Peek (try_wait) SCALE buffer empty + scale_producer_state.reset_count() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # noqa: B007 + # + # Slice to per mma tile index + # + tAsSFA_pipe = cute.filter_zeros( + tAsSFA[(None, None, None, scale_producer_state.index)] + ) + tBsSFB_pipe = cute.filter_zeros( + tBsSFB[(None, None, None, scale_producer_state.index)] + ) + tAgSFA_k = cute.filter_zeros( + tAgSFA_mkl[ + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + tBgSFB_k = cute.filter_zeros( + tBgSFB_nkl[ + ( + None, + None, + None, + tile_info[1], + 
scale_producer_state.count, + tile_info[2], + ) + ] + ) + + tAcSFA_compact = cute.filter_zeros( + cute.slice_( + tAcSFA, + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + tBcSFB_compact = cute.filter_zeros( + cute.slice_( + tBcSFB, + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + # TODO: Skip more unnecessary load + for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])): + tApSFA[((0, 0), i, (0, 0))] = cute.elem_less( + tAcSFA_compact[(i)][0], mSFA_mkl.shape[0] + ) + for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])): + tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less( + tBcSFB_compact[(i)][0], mSFB_nkl.shape[0] + ) + + # Conditionally wait for Scale buffer empty + scale_pipeline.producer_acquire( + scale_producer_state, peek_scale_empty_status + ) + + # load scaleA/scaleB + cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA) + cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB) + + scale_pipeline.producer_commit(scale_producer_state) + + # Peek (try_wait) Scale buffer empty + scale_producer_state.advance() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait Scale buffer empty + # + scale_pipeline.producer_tail(scale_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem_ptr_read_threads = 32 * len( + (self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id) + ) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_fragment(cute.make_layout((4,)).shape, cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + 
is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Mma mainloop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # noqa: B007 + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire( + acc_producer_state, peek_acc_empty_status + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized acc update warps + # + if warp_idx <= self.acc_update_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps) + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len( + (self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id) + ) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + 
alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc_base, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc, + tRT_tAcc_base, + ) = self.acc_update_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCtAcc_final, + tCgC, + sSFA_view_as_C, + sSFB_view_as_C, + epi_tile, + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + scale_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_scale_stage + ) + + epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 1 + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_fragment(cute.make_layout((4,)).shape, cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # initialize the final accumulator + tTR_rAcc_final.fill(0.0) + + tTR_rSFA = cute.make_fragment( + cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + tTR_rSFB = cute.make_fragment( + cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + + scale_consumer_state.reset_count() + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + + acc_consumer_state.reset_count() + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # noqa: B007 + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for scale buffer full + # + scale_pipeline.consumer_wait( + scale_consumer_state, peek_scale_full_status + ) + + tTR_sSFA_slice = cute.slice_( + tTR_sSFA, + (None, None, None, 0, None, scale_consumer_state.index), + ) + tTR_sSFB_slice = cute.slice_( + tTR_sSFB, + (None, None, None, 0, None, scale_consumer_state.index), + ) + + scale_atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA) + cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB) + + # + # Wait for accumulator buffer full + # + + acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Update accumulator by scale 
factor in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Update accumulator by scale factor + # + tTR_rAcc_subtile = tTR_rAcc_final[ + (None, None, None, subtile_idx) + ] + tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)] + tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)] + + acc_vec = tTR_rAcc.load() + final_vec = tTR_rAcc_subtile.load() + scale_a = tTR_rSFA_subtile.load() + scale_b = tTR_rSFB_subtile.load() + scale = scale_a * scale_b + final_vec = acc_vec * scale + final_vec + tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) + + # + # Async arrive accumulator buffer empty + # + scale_pipeline.consumer_release(scale_consumer_state) + scale_consumer_state.advance() + + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait( + acc_consumer_state + ) + + tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)] + tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc)) + + # + # Wait for epilogue buffer empty + # + epi_pipeline.producer_acquire(epi_producer_state) + + # copy the accumulator to tensor memory buffer + cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc) + cute.arch.fence_view_async_tmem_store() + + # + # Async arrive epilogue buffer full + # + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]: + cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps) + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len( + (self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id) + ) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + 
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs + ) + ) + + tTR_rC = None + tiled_copy_r2s = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC_partitioned = ( + self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + epi_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, 1 + ) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_fragment(cute.make_layout((4,)).shape, cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + num_prev_subtiles = cutlass.Int32(0) + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + mma_tile_coord_mnl[2], + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, epi_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + epi_pipeline.consumer_wait(epi_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + c_buffer = num_prev_subtiles % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * 
len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # Async arrive accumulator buffer empty + # + epi_pipeline.consumer_release(epi_consumer_state) + epi_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def acc_update_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + tAcc_final: cute.Tensor, + gC_mnl: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epi_tile: cute.Tile, + ) -> Tuple[ + cute.TiledCopy, + cute.TiledCopy, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + ]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + Make tiledCopy for tensor memory store, then use it to partition register array (source) and tensor memory (destination). + Partition the scale factor tensor for related copy operations. 
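+
+        Conceptually, the partitions returned here serve the per-k-tile update
+        performed by the acc update warps (a sketch with illustrative names, not
+        the exact partition shapes)::
+
+            final = acc_k * (sfa * sfb) + final
+
+        where acc_k is the k-tile partial accumulator read from TMEM with
+        tiled_copy_t2r, sfa/sfb are the staged scale factors read from SMEM, and
+        final is written back to TMEM with tiled_copy_r2t after the last k-tile.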
+ + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param tAcc_final: The final accumulator tensor to be copied and partitioned + :type tAcc_final: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param sSFA: The scale factor tensor for A + :type sSFA: cute.Tensor + :param sSFB: The scale factor tensor for B + :type sSFB: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + - tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results + - tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r + - tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r + - tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results + - tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t + :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + tmem_load_atom = None + tmem_store_atom = None + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + elif cutlass.const_expr(self.mma_tiler[0] == 128): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + else: + # default: 16dp + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(1)), + self.acc_dtype, + ) + + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tAcc_final_epi = cute.flat_divide( + tAcc_final[((None, None), 0, 0, None)], epi_tile + ) + + tiled_copy_t2r = tcgen05.make_tmem_copy( + tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)] + ) + tiled_copy_r2t = tcgen05.make_tmem_copy( + tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sSFA_epi = cute.flat_divide(sSFA, epi_tile) + sSFB_epi = cute.flat_divide(sSFB, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi) + tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 
0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc_final_ = cute.make_fragment( + tTR_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + tTR_rAcc_final = cute.group_modes( + tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_) + ) + + tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi) + tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi) + # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL) + tRT_rAcc_final_ = cute.make_fragment( + tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N)) + tRT_rAcc_final = cute.group_modes( + tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_) + ) + + return ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc_final, + tRT_tAcc_final, + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
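+
+        The register-to-smem (r2s) store atom is derived from the C layout/dtype and
+        the given t2r tiled copy, so the fragment produced by the tmem load can be
+        retiled directly into one stage of the staged shared-memory C buffer. A
+        minimal sketch of the intended use (``c_buffer`` is the stage index picked
+        by the epilogue loop above; the slicing mirrors that loop):
+
+        .. code-block:: python
+
+            tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
+                tiled_copy_t2r, tTR_rC, epi_tidx, sC
+            )
+            tRS_rC.store(acc_vec)  # converted epilogue results for one sub-tile
+            cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])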
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sfa_dtype: Type[cutlass.Numeric], + sfb_dtype: Type[cutlass.Numeric], + sfa_count: int, + sfb_count: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. 
+ :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout of operand C. + :type c_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6 + + # Default C stages + num_c_stage = 2 + + # Default ScaleA/B stages + num_scale_stage = 10 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + # 1024B alignment + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage + sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage + scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024 + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity // occupancy + - (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. 
+ :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[ + cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp + ]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {cutlass.Float32}: + is_valid = False + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + if 
mma_tiler_mn[1] not in (128,): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if 
the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not BlockwiseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not BlockwiseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not BlockwiseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip unsupported A/B layout + if not (a_major == "k" and b_major == "k"): + can_implement = False + return can_implement + + +class BlockwiseGemmCuteDSL: + def __init__( + self, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ab_dtype: torch.dtype, + sf_dtype: torch.dtype, + c_dtype: torch.dtype, + acc_dtype: torch.dtype, + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + sm_count: int, + sm_version: str, + ): + self._m = m + self._n = n + self._k = k + self._l = l + self._a_major = a_major + self._b_major = b_major + self._c_major = c_major + self._ab_dtype = ab_dtype + self._sf_dtype = sf_dtype + self._acc_dtype = acc_dtype + self._c_dtype = c_dtype + self._use_2cta_instrs = use_2cta_instrs + self._mma_tiler_mn = mma_tiler_mn + self._cluster_shape_mn = cluster_shape_mn + + if not BlockwiseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + hardware_info = cutlass.utils.HardwareInfo() + self._max_active_clusters = min( + hardware_info.get_max_active_clusters( + self._cluster_shape_mn[0] * self._cluster_shape_mn[1] + ), + sm_count, + ) + self._sm_version = sm_version + + @cute.jit + def __call__( + self, + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb_ptr: cute.Pointer, + c_ptr: cute.Pointer, + current_stream: cuda.CUstream, + ): + a_tensor = cute.make_tensor( + a_ptr, + layout=cute.make_ordered_layout( + (self._m, self._k, self._l), + order=(0, 1, 2) if self._a_major == "m" else (1, 0, 2), + ), + ) + b_tensor = cute.make_tensor( + b_ptr, + layout=cute.make_ordered_layout( + (self._n, self._k, self._l), + order=(0, 1, 2) if self._b_major == "n" else (1, 0, 2), + ), + ) + c_tensor = cute.make_tensor( + c_ptr, + layout=cute.make_ordered_layout( + (self._m, self._n, self._l), + order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2), + ), + ) + sfa_tensor = cute.make_tensor( + sfa_ptr, + layout=cute.make_ordered_layout( + (self._m, math.ceil(self._k / 128), self._l), + order=(0, 1, 2), + ), + ) + sfb_tensor = cute.make_tensor( + sfb_ptr, + layout=cute.make_ordered_layout( + (math.ceil(self._n / 128), math.ceil(self._k / 128), self._l), + order=(1, 0, 2), + ), + ) + + BlockwiseGemmKernel( + acc_dtype=self._acc_dtype, + use_2cta_instrs=self._use_2cta_instrs, + mma_tiler_mn=self._mma_tiler_mn, + cluster_shape_mn=self._cluster_shape_mn, + )( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + self._max_active_clusters, + current_stream, + ) + + +@functools.cache +def get_cute_dsl_compiled_blockwise_gemm_kernel( + m: int, + n: int, + k: int, + l: int, + 
a_major: str, + b_major: str, + c_major: str, + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + sm_count: int, + sm_version: str, +) -> Callable: + def get_cute_pointers( + input_tensors: Optional[List[torch.tensor]], + ) -> List[cute.Pointer]: + if input_tensors is None: + ( + a_data_ptr, + b_data_ptr, + sfa_data_ptr, + sfb_data_ptr, + c_data_ptr, + ) = [16 for _ in range(5)] + else: + ( + a_tensor_gpu, + b_tensor_gpu, + sfa_tensor_gpu, + sfb_tensor_gpu, + c_tensor_gpu, + ) = input_tensors + + ( + a_data_ptr, + b_data_ptr, + sfa_data_ptr, + sfb_data_ptr, + c_data_ptr, + ) = ( + a_tensor_gpu.data_ptr(), + b_tensor_gpu.data_ptr(), + sfa_tensor_gpu.data_ptr(), + sfb_tensor_gpu.data_ptr(), + c_tensor_gpu.data_ptr(), + ) + + a_ptr = make_ptr( + ab_dtype, + a_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + ab_dtype, + b_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + sfa_ptr = make_ptr( + sf_dtype, + sfa_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + sfb_ptr = make_ptr( + sf_dtype, + sfb_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + c_ptr = make_ptr( + c_dtype, + c_data_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + return [a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr] + + kernel = cute.compile( + BlockwiseGemmCuteDSL( + m=m, + n=n, + k=k, + l=l, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sm_count=sm_count, + sm_version=sm_version, + ), + *get_cute_pointers(None), + cutlass_torch.current_stream(), + ) + + def tensor_api( + a_tensor_gpu: torch.Tensor, + b_tensor_gpu: torch.Tensor, + sfa_tensor_gpu: torch.Tensor, + sfb_tensor_gpu: torch.Tensor, + c_tensor_gpu: Optional[torch.Tensor] = None, + ): + if c_tensor_gpu is None: + c_tensor_gpu = torch.empty( + (l, m, n), + dtype=cutlass_to_torch_dtype(c_dtype), + device="cuda", + ) + + current_stream = cutlass_torch.current_stream() + + nonlocal kernel + kernel( + *get_cute_pointers( + [ + a_tensor_gpu, + b_tensor_gpu, + sfa_tensor_gpu, + sfb_tensor_gpu, + c_tensor_gpu, + ] + ), + current_stream, + ) + return c_tensor_gpu + + return tensor_api + + +def blockwise_gemm( + a_torch: torch.Tensor, + sfa_torch: torch.Tensor, + b_torch: torch.Tensor, + sfb_torch: torch.Tensor, + c_torch: torch.Tensor, + *, + ab_dtype: str, + sf_dtype: str, + c_dtype: str, + acc_dtype: str, + sm_count: Optional[int] = None, + **kwargs, +): + m, k, l = a_torch.shape + n, _, _ = b_torch.shape + + mma_tiler_mn = kwargs.pop("mma_tiler_mn", (128, 128)) + cluster_shape_mn = kwargs.pop("cluster_shape_mn", (1, 1)) + if sm_count is None: + sm_count = get_num_sm(a_torch.device) + use_2cta_instrs = kwargs.pop("use_2cta_instrs", False) + assert len(kwargs) == 0, f"Unsupported kwargs: {kwargs}" + + major, minor = get_compute_capability(a_torch.device) + if major == 11 and minor == 0: + raise ValueError("SM110 is not supported for cute-dsl backend.") + + return get_cute_dsl_compiled_blockwise_gemm_kernel( + m=m, + n=n, + k=k, + l=l, + a_major="k", + b_major="k", + c_major="n", + ab_dtype=get_cutlass_dtype(ab_dtype), + sf_dtype=get_cutlass_dtype(sf_dtype), + c_dtype=get_cutlass_dtype(c_dtype), + 
acc_dtype=get_cutlass_dtype(acc_dtype), + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + sm_count=sm_count, + sm_version=f"sm_{major}{minor}", + )( + a_tensor_gpu=a_torch, + b_tensor_gpu=b_torch, + sfa_tensor_gpu=sfa_torch, + sfb_tensor_gpu=sfb_torch, + c_tensor_gpu=c_torch, + ) diff --git a/tests/gemm/test_cute_dsl_blockwise_gemm.py b/tests/gemm/test_cute_dsl_blockwise_gemm.py new file mode 100644 index 0000000000..a34514e3c9 --- /dev/null +++ b/tests/gemm/test_cute_dsl_blockwise_gemm.py @@ -0,0 +1,251 @@ +import math +import pytest +from typing import Tuple + +import cutlass +import cutlass.torch as cutlass_torch +import torch + +from flashinfer.cute_dsl.blockwise_gemm import BlockwiseGemmKernel, blockwise_gemm +from flashinfer.cute_dsl.utils import ( + get_cutlass_dtype, + get_num_sm, + is_cute_dsl_available, +) + + +def create_tensors( + l, m, n, k, a_major, b_major, cd_major, ab_dtype, c_dtype, scale_dtype, device +): + torch.manual_seed(42) + + a_torch_cpu = cutlass_torch.matrix( + l, m, k, a_major == "m", get_cutlass_dtype(ab_dtype), device=device + ) + b_torch_cpu = cutlass_torch.matrix( + l, n, k, b_major == "n", get_cutlass_dtype(ab_dtype), device=device + ) + c_torch_cpu = cutlass_torch.matrix( + l, m, n, cd_major == "m", get_cutlass_dtype(c_dtype), device=device + ) + sfa_torch_cpu = cutlass_torch.matrix( + l, m, math.ceil(k / 128), True, get_cutlass_dtype(scale_dtype), device=device + ) + sfb_torch_cpu = cutlass_torch.matrix( + l, + math.ceil(n / 128), + math.ceil(k / 128), + False, + get_cutlass_dtype(scale_dtype), + device=device, + ) + + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a_torch_cpu, + get_cutlass_dtype(ab_dtype), + is_dynamic_layout=True, + assumed_align=16, + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b_torch_cpu, + get_cutlass_dtype(ab_dtype), + is_dynamic_layout=True, + assumed_align=16, + ) + c_tensor, c_torch = cutlass_torch.cute_tensor_like( + c_torch_cpu, + get_cutlass_dtype(c_dtype), + is_dynamic_layout=True, + assumed_align=16, + ) + sfa_tensor, sfa_torch = cutlass_torch.cute_tensor_like( + sfa_torch_cpu, + get_cutlass_dtype(scale_dtype), + is_dynamic_layout=True, + assumed_align=16, + ) + sfb_tensor, sfb_torch = cutlass_torch.cute_tensor_like( + sfb_torch_cpu, + get_cutlass_dtype(scale_dtype), + is_dynamic_layout=True, + assumed_align=16, + ) + + return ( + a_tensor, + a_torch, + b_tensor, + b_torch, + c_tensor, + c_torch, + sfa_tensor, + sfa_torch, + sfb_tensor, + sfb_torch, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + ) + + +@pytest.mark.skipif( + not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`" +) +@pytest.mark.parametrize("lm", [(1, 256)]) +@pytest.mark.parametrize("kn", [(512, 256)]) +@pytest.mark.parametrize( + "ab_dtype,sf_dtype,c_dtype,acc_dtype", + [ + ("float8_e4m3fn", "float32", "bfloat16", "float32"), + ], +) +@pytest.mark.parametrize("a_major", ["k"]) +@pytest.mark.parametrize("b_major", ["k"]) +@pytest.mark.parametrize("c_major", ["n"]) +@pytest.mark.parametrize("use_2cta_instrs", [False]) +@pytest.mark.parametrize("mma_tiler_mn", [(128, 128)]) +@pytest.mark.parametrize("cluster_shape_mn", [(1, 1)]) +@pytest.mark.parametrize("tolerance", [1e-01]) +@pytest.mark.parametrize("iterations", [3]) +def test_blockwise_gemm_python_interface( + lm: Tuple[int, int], + kn: Tuple[int, int], + ab_dtype: cutlass.dtype, + sf_dtype: cutlass.dtype, + c_dtype: cutlass.dtype, + acc_dtype: cutlass.dtype, + a_major: 
str, + b_major: str, + c_major: str, + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float, + iterations: int, +): + torch.manual_seed(42) + device = torch.device("cuda:0") + major, minor = torch.cuda.get_device_capability(device) + + if not (major == 10 and minor == 0): + pytest.skip("Cute-dsl backend is only supported on SM100.") + + l, m = lm + k, n = kn + + sm_count = get_num_sm(device) + + print(f"device: {device}") + + if not BlockwiseGemmKernel.can_implement( + get_cutlass_dtype(ab_dtype), + get_cutlass_dtype(acc_dtype), + get_cutlass_dtype(c_dtype), + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + pytest.skip( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {c_dtype}, {acc_dtype}, {use_2cta_instrs} ,{mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + ( + a_tensor, + a_torch, + b_tensor, + b_torch, + c_tensor, + c_torch, + sfa_tensor, + sfa_torch, + sfb_tensor, + sfb_torch, + a_torch_cpu, + b_torch_cpu, + c_torch_cpu, + sfa_torch_cpu, + sfb_torch_cpu, + ) = create_tensors( + l, m, n, k, a_major, b_major, c_major, ab_dtype, c_dtype, sf_dtype, device + ) + + for _ in range(iterations): + blockwise_gemm( + a_torch, + sfa_torch, + b_torch, + sfb_torch, + c_torch, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + sm_count=sm_count, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + use_2cta_instrs=use_2cta_instrs, + ) + + torch.cuda.synchronize() + + def pad_and_multiply(scale, tensor): + cm, ck, _ = scale.shape + m, k, _ = tensor.shape + IsGroupWise = False + IsBlockWise = False + if ck == math.ceil(k / 128): + IsGroupWise = True + if cm == math.ceil(m / 128): + IsBlockWise = True + if not IsBlockWise and not IsGroupWise: + raise ValueError("Only support granularity = 128") + + k_idx = torch.arange(k, device=scale.device) + if IsGroupWise: + k_idx = k_idx // 128 + m_idx = torch.arange(m, device=scale.device) + if IsBlockWise: + m_idx = m_idx // 128 + expanded_scale = scale[m_idx[:, None], k_idx, :] + + result = expanded_scale * tensor + + return result + + updated_a = pad_and_multiply(sfa_torch_cpu, a_torch_cpu) + updated_b = pad_and_multiply(sfb_torch_cpu, b_torch_cpu) + + ref = torch.einsum("mkl,nkl->mnl", updated_a, updated_b).to( + cutlass_torch.dtype(get_cutlass_dtype(c_dtype)) + ) + res = c_torch.view(cutlass_torch.dtype(get_cutlass_dtype(c_dtype))) + + torch.testing.assert_close(res.cpu(), ref.cpu(), atol=tolerance, rtol=1e-03) + + +if __name__ == "__main__": + test_blockwise_gemm_python_interface( + lm=(1, 256), + kn=(512, 256), + ab_dtype="float8_e4m3fn", + sf_dtype="float32", + c_dtype="bfloat16", + acc_dtype="float32", + a_major="k", + b_major="k", + c_major="n", + use_2cta_instrs=False, + mma_tiler_mn=(128, 128), + cluster_shape_mn=(1, 1), + tolerance=1e-01, + iterations=3, + )
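+
+
+# Note on scale-factor granularity (a summary of the shapes this test builds via
+# create_tensors, assuming the 128-wide blocking that pad_and_multiply checks for):
+#   sfa: (m, ceil(k / 128), l)               -- one scale per row of A per K block
+#   sfb: (ceil(n / 128), ceil(k / 128), l)   -- one scale per 128x128 block of B
+# e.g. m=256, n=256, k=512, l=1 gives sfa of shape (256, 4, 1) and sfb of shape
+# (2, 4, 1).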