@@ -86,21 +86,6 @@ def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
8686 barrier_size = Sm100BlockScaledPersistentDenseGemmKernel .compute_barrier_flag_size (
8787 m , n , l , mma_tiler_mn , cluster_shape_mn , sm_count
8888 )
89- #print("LOOK HERE", (barrier_size,))
90- # NOTE: use_2cta_instrs from blockedscaled_gemm logic
91-
92- # use_2cta_instrs = mma_tiler_mn[0] == 256
93- # cta_tile_shape_mn = (
94- # mma_tiler_mn[0] // (2 if use_2cta_instrs else 1),
95- # mma_tiler_mn[1],
96- # )
97- # problem_shape_ntile_mn = (m // cta_tile_shape_mn[0], n // cta_tile_shape_mn[1])
98- # num_tiles_per_batch = problem_shape_ntile_mn[0] * problem_shape_ntile_mn[1]
99- # num_tiles = num_tiles_per_batch * l
100- # num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
101- # +num_sms for final barrier
102- # num_tiles + num_sms
103-
10489 barrier_flag = symm_mem .empty ((barrier_size ,), device = "cuda" , dtype = torch .int32 )
10590
10691 barrier_flag .fill_ (0 )
@@ -158,8 +143,6 @@ def run_blockscaled_gemm_all_reduce_python_interface(
158143 l , m = lm
159144 k , n = kn
160145
161- #print(f"device: {device}")
162-
163146 if not Sm100BlockScaledPersistentDenseGemmKernel .can_implement (
164147 get_cutlass_dtype (ab_dtype ),
165148 get_cutlass_dtype (sf_dtype ),
@@ -201,7 +184,6 @@ def run_blockscaled_gemm_all_reduce_python_interface(
201184 init_type = cutlass_torch .TensorInitType .SCALAR ,
202185 init_config = cutlass_torch .ScalarInitConfig (value = 0.0 ),
203186 )
204- #print(f"Rank {rank}: c_ref INITIAL shape={c_ref.shape}, stride={c_ref.stride()}")
205187 a_tensor , a_torch = cutlass_torch .cute_tensor_like (
206188 a_ref ,
207189 get_cutlass_dtype (ab_dtype ),
@@ -214,21 +196,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
214196 is_dynamic_layout = True ,
215197 assumed_align = 16 ,
216198 )
217- # c_tensor, c_torch = cutlass_torch.cute_tensor_like(
218- # c_ref,
219- # get_cutlass_dtype(c_dtype),
220- # is_dynamic_layout=True,
221- # assumed_align=16,
222- # )
223199 c_tensor , c_tensor_mc , c_torch , c_torch_mc = create_mc_tensor (
224200 c_ref ,
225201 get_cutlass_dtype (c_dtype ),
226202 # (1 if c_major == "n" else 0),
227203 is_dynamic_layout = True ,
228204 )
229- # print(
230- # f"Rank {rank}: c_torch INITIAL shape={c_torch.shape}, stride={c_torch.stride()}"
231- # )
232205 alpha_tensor = (
233206 torch .randn (l , dtype = torch .float32 , device = device ) if fuse_alpha else None
234207 )
@@ -271,15 +244,11 @@ def run_blockscaled_gemm_all_reduce_python_interface(
271244 sfb_ref , sfb_tensor , sfb_torch = create_scale_factor_tensor (
272245 l , n , k , sf_vec_size , get_cutlass_dtype (sf_dtype ), device
273246 )
274- # masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
275247 if rank == 0 :
276248 masked_m_tensor = torch .randint (0 , m , (l ,), dtype = torch .int32 , device = device )
277249 else :
278250 masked_m_tensor = torch .empty ((l ,), dtype = torch .int32 , device = device )
279251 torch .distributed .broadcast (masked_m_tensor , src = 0 )
280- # to hack and test:
281- # masked_m_tensor = torch.full((l,), m, dtype=torch.int32, device=device)
282- # print(f"Rank {rank}: masked_m = {masked_m_tensor}")
283252 for _ in range (iterations ):
284253 dst_signals = (
285254 torch .zeros ((l ,), dtype = torch .uint32 , device = "cuda" )
@@ -328,18 +297,12 @@ def run_blockscaled_gemm_all_reduce_python_interface(
328297 )
329298 # Convert c back to f32 for comparison.
330299 ref = ref .permute (2 , 0 , 1 ).contiguous ().permute (1 , 2 , 0 )
331- # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
332- # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
333- # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
334300 cute .testing .convert (
335301 c_tensor ,
336302 from_dlpack (c_ref , assumed_align = 16 ).mark_layout_dynamic (
337303 leading_dim = (1 if c_major == "n" else 0 )
338304 ),
339305 )
340- # print(f"Rank {rank}: c_ref shape={c_ref.shape}, stride={c_ref.stride()}")
341- # print(f"Rank {rank}: ref shape={ref.shape}, stride={ref.stride()}")
342- # print(f"Rank {rank}: c_torch shape={c_torch.shape}, stride={c_torch.stride()}")
343306 if c_dtype in ("float32" , "float16" , "bfloat16" ):
344307 for i in range (l ):
345308 # skip testing c_ref & ref
@@ -481,23 +444,23 @@ def multi_process_parallel(
481444@pytest .mark .parametrize (
482445 "ab_dtype,sf_dtype,c_dtype,sf_vec_size" ,
483446 [
484- # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
485- # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
486- # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
487- # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
488- # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
489- # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
490- # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
491- # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
492- # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
493- # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
494- ("float8_e4m3fn" , "float8_e8m0fnu" , "float8_e4m3fn" , 32 ),
495- # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
496- # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
497- # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
498- # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
447+ ("float8_e5m2" , "float8_e8m0fnu" , "bfloat16" , 32 ),
448+ ("float4_e2m1fn" , "float8_e8m0fnu" , "float16" , 16 ),
449+ ("float4_e2m1fn" , "float8_e8m0fnu" , "bfloat16" , 16 ),
450+ ("float4_e2m1fn" , "float8_e8m0fnu" , "float32" , 16 ),
451+ ("float4_e2m1fn" , "float8_e4m3fn" , "float16" , 16 ),
452+ ("float4_e2m1fn" , "float8_e4m3fn" , "bfloat16" , 16 ),
453+ ("float4_e2m1fn" , "float8_e4m3fn" , "float32" , 16 ),
454+ ("float8_e4m3fn" , "float8_e8m0fnu" , "bfloat16" , 32 ),
455+ ("float8_e4m3fn" , "float8_e8m0fnu" , "float16" , 32 ),
456+ ("float8_e4m3fn" , "float8_e8m0fnu" , "float32" , 32 ),
457+ # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
458+ ("float8_e4m3fn" , "float8_e8m0fnu" , "float8_e5m2" , 32 ),
459+ ("float8_e5m2" , "float8_e8m0fnu" , "bfloat16" , 32 ),
460+ ("float8_e5m2" , "float8_e8m0fnu" , "float16" , 32 ),
461+ ("float8_e5m2" , "float8_e8m0fnu" , "float32" , 32 ),
499462 # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
500- # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
463+ ("float8_e5m2" , "float8_e8m0fnu" , "float8_e5m2" , 32 ),
501464 ],
502465)
503466@pytest .mark .parametrize ("a_major" , ["k" ])
@@ -538,7 +501,6 @@ def test_cute_dsl_blockscaled_gemm_allreduce_two_shot(
538501 pytest .skip (
539502 f"world_size { world_size } is greater than available_gpus { available_gpus } "
540503 )
541- #device = torch.device("cuda", rank)
542504 major , minor = torch .cuda .get_device_capability (torch .device ("cuda:0" ))
543505 if not (major == 10 and minor == 0 ):
544506 pytest .skip ("Cute-dsl backend is only supported on SM100." )
0 commit comments