@@ -26,7 +26,8 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Optional, Tuple, Type, Union
+import functools
+from typing import Callable, Optional, Tuple, Type, Union, List
 
 import cuda.bindings.driver as cuda
 import cutlass
@@ -37,26 +38,21 @@
 import cutlass.utils.blackwell_helpers as sm100_utils
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 import cutlass.utils.distributed_helpers as distributed_helpers
-import torch
-import functools
 from cutlass._mlir import ir
 from cutlass.cute.nvgpu import cpasync, tcgen05
 from cutlass.cute.runtime import from_dlpack
-
 from cutlass.cutlass_dsl import (
-    Int32,
-    Int64,
-    Uint8,
-    Uint64,
     T,
     Integer,
     dsl_user_op,
     extract_mlir_values,
     new_from_mlir_values,
 )
-
 from cutlass.cute.typing import (
     Int32,
+    Int64,
+    Uint8,
+    Uint64,
     Float16,
     BFloat16,
     Float32,
@@ -65,10 +61,11 @@
     Tensor,
 )
 from cutlass._mlir.dialects import llvm
-from flashinfer.utils import get_compute_capability
 from cutlass.utils.static_persistent_tile_scheduler import WorkTileInfo
+import torch
+
+from flashinfer.utils import get_compute_capability
 from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr
-from typing import Callable, List
 
 
 sizeof_i32 = 4
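
Taken together, the import hunks above normalize the module header into the conventional three groups (standard library, third-party, project-local): `functools` and the `typing` names are consolidated at the top, `torch` moves below the cutlass imports, the duplicate `Int32` is dropped from `cutlass.cutlass_dsl`, and `Int64`, `Uint8`, and `Uint64` are sourced from `cutlass.cute.typing` instead. A minimal sketch of the resulting ordering, abridged to the names visible in this diff:

    # Standard library imports first.
    import functools
    from typing import Callable, List, Optional, Tuple, Type, Union

    # Third-party packages next: CUDA driver bindings, the cutlass DSL, torch.
    import cuda.bindings.driver as cuda
    import cutlass
    import torch

    # Project-local imports last.
    from flashinfer.utils import get_compute_capability
    from .utils import get_cutlass_dtype, cutlass_to_torch_dtype, get_num_sm, make_ptr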
@@ -1865,7 +1862,6 @@ def kernel(
         # Allreduce
         #
         if cutlass.const_expr(self.all_reduce == "two_shot"):
-
             tile_id = Int32(
                 tile_sched._current_work_linear_idx
                 * cute.size(self.cluster_shape_mn)
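
The blank line removed here sits at the top of the two-shot allreduce epilogue, where a globally unique tile id is derived from the persistent tile scheduler's linear work index; the visible part of the expression scales that index by the cluster size, and the remainder is cut off by the hunk's context window. A hedged sketch of that indexing scheme; the rank term is an assumption, not shown in the diff:

    def tile_id_for(work_idx: int, cluster_size: int, rank_in_cluster: int) -> int:
        # Assumed mapping: each scheduler step covers one cluster's worth of
        # CTA tiles, so scale the linear index by the cluster size and offset
        # by the CTA's rank in the cluster (the offset term is an assumption;
        # the diff truncates before the rest of the expression).
        return work_idx * cluster_size + rank_in_cluster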
@@ -2950,13 +2946,15 @@ def __call__(
         current_stream: cuda.CUstream,
     ):
         if cutlass.const_expr(self._all_reduce != "none"):
-            barrier_flag_size = Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
-                self._m,
-                self._n,
-                self._l,
-                self._mma_tiler_mn,
-                self._cluster_shape_mn,
-                self._max_active_clusters,
+            barrier_flag_size = (
+                Sm100BlockScaledPersistentDenseGemmKernel.compute_barrier_flag_size(
+                    self._m,
+                    self._n,
+                    self._l,
+                    self._mma_tiler_mn,
+                    self._cluster_shape_mn,
+                    self._max_active_clusters,
+                )
             )
         else:
             barrier_flag_size = 1  # Dummy size when not used
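
compute_barrier_flag_size itself is outside this diff; it sizes the flag array that the two-shot allreduce uses to synchronize ranks, as a function of the problem extents and the tiling. A hypothetical sketch of that kind of sizing logic, not the kernel's actual implementation:

    def barrier_flag_count(m, n, l, mma_tiler_mn, cluster_shape_mn):
        # Hypothetical sizing: one flag slot per output tile per batch,
        # rounded up to a whole number of clusters.
        def ceil_div(a, b):
            return (a + b - 1) // b

        num_tiles = ceil_div(m, mma_tiler_mn[0]) * ceil_div(n, mma_tiler_mn[1]) * l
        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
        return ceil_div(num_tiles, cluster_size) * cluster_size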
@@ -2982,21 +2980,33 @@ def __call__(
                 order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
             ),
         )
-        c_mc_tensor = cute.make_tensor(
-            c_mc_ptr,
-            layout=cute.make_ordered_layout(
-                (self._m, self._n, self._l),
-                order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
-            ),
-        ) if c_mc_ptr is not None else None
-        barrier_flag_tensor = cute.make_tensor(
-            barrier_flag_ptr,
-            layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
-        ) if barrier_flag_ptr is not None else None
-        barrier_flag_mc_tensor = cute.make_tensor(
-            barrier_flag_mc_ptr,
-            layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
-        ) if barrier_flag_mc_ptr is not None else None
+        c_mc_tensor = (
+            cute.make_tensor(
+                c_mc_ptr,
+                layout=cute.make_ordered_layout(
+                    (self._m, self._n, self._l),
+                    order=(0, 1, 2) if self._c_major == "m" else (1, 0, 2),
+                ),
+            )
+            if c_mc_ptr is not None
+            else None
+        )
+        barrier_flag_tensor = (
+            cute.make_tensor(
+                barrier_flag_ptr,
+                layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
+            )
+            if barrier_flag_ptr is not None
+            else None
+        )
+        barrier_flag_mc_tensor = (
+            cute.make_tensor(
+                barrier_flag_mc_ptr,
+                layout=cute.make_ordered_layout((barrier_flag_size,), order=(0,)),
+            )
+            if barrier_flag_mc_ptr is not None
+            else None
+        )
 
         # calculate sf_tensor shape and order
         def ceil_div(a, b):
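
The reflow above also exposes a repeated pattern: each multicast or barrier tensor is constructed only when its backing pointer is present, and the slot is left as None otherwise (the all_reduce == "none" path). If one wanted to factor it out, a sketch using the same cute calls as the hunk (the helper name is hypothetical):

    def maybe_make_tensor(ptr, shape, order):
        # Build the tensor only when the pointer exists; None signals that
        # the corresponding feature (e.g. all-reduce) is disabled.
        if ptr is None:
            return None
        return cute.make_tensor(
            ptr, layout=cute.make_ordered_layout(shape, order=order)
        )

    # Usage, matching the hunk above:
    # barrier_flag_tensor = maybe_make_tensor(
    #     barrier_flag_ptr, (barrier_flag_size,), (0,)
    # )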
@@ -3154,7 +3164,6 @@ def get_cute_pointers(
             c_mc_data_ptr,
             barrier_flag_data_ptr,
             barrier_flag_mc_data_ptr,
-
         ) = (
             a_tensor_gpu.data_ptr(),
             b_tensor_gpu.data_ptr(),
@@ -3168,7 +3177,9 @@ def get_cute_pointers(
             alpha_tensor_gpu.data_ptr() if alpha_tensor_gpu is not None else None,
             c_mc_gpu.data_ptr() if c_mc_gpu is not None else None,
             barrier_flag_gpu.data_ptr() if barrier_flag_gpu is not None else None,
-            barrier_flag_mc_gpu.data_ptr() if barrier_flag_mc_gpu is not None else None,
+            barrier_flag_mc_gpu.data_ptr()
+            if barrier_flag_mc_gpu is not None
+            else None,
         )
 
         a_ptr = make_ptr(