@@ -86,7 +86,7 @@ def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
8686 barrier_size = Sm100BlockScaledPersistentDenseGemmKernel .compute_barrier_flag_size (
8787 m , n , l , mma_tiler_mn , cluster_shape_mn , sm_count
8888 )
89- print ("LOOK HERE" , (barrier_size ,))
89+ # print("LOOK HERE", (barrier_size,))
9090 # NOTE: use_2cta_instrs from blockedscaled_gemm logic
9191
9292 # use_2cta_instrs = mma_tiler_mn[0] == 256
@@ -481,23 +481,23 @@ def multi_process_parallel(
481481@pytest .mark .parametrize (
482482 "ab_dtype,sf_dtype,c_dtype,sf_vec_size" ,
483483 [
484- ("float8_e5m2" , "float8_e8m0fnu" , "bfloat16" , 32 )
485- ("float4_e2m1fn" , "float8_e8m0fnu" , "float16" , 16 ),
486- ("float4_e2m1fn" , "float8_e8m0fnu" , "bfloat16" , 16 ),
487- ("float4_e2m1fn" , "float8_e8m0fnu" , "float32" , 16 ),
488- ("float4_e2m1fn" , "float8_e4m3fn" , "float16" , 16 ),
489- ("float4_e2m1fn" , "float8_e4m3fn" , "bfloat16" , 16 ),
490- ("float4_e2m1fn" , "float8_e4m3fn" , "float32" , 16 ),
491- ("float8_e4m3fn" , "float8_e8m0fnu" , "bfloat16" , 32 ),
492- ("float8_e4m3fn" , "float8_e8m0fnu" , "float16" , 32 ),
493- ("float8_e4m3fn" , "float8_e8m0fnu" , "float32" , 32 ),
484+ # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32)
485+ # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
486+ # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
487+ # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
488+ # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
489+ # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
490+ # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
491+ # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
492+ # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
493+ # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
494494 ("float8_e4m3fn" , "float8_e8m0fnu" , "float8_e4m3fn" , 32 ),
495- ("float8_e4m3fn" , "float8_e8m0fnu" , "float8_e5m2" , 32 ),
496- ("float8_e5m2" , "float8_e8m0fnu" , "bfloat16" , 32 ),
497- ("float8_e5m2" , "float8_e8m0fnu" , "float16" , 32 ),
498- ("float8_e5m2" , "float8_e8m0fnu" , "float32" , 32 ),
499- ("float8_e5m2" , "float8_e8m0fnu" , "float8_e4m3fn" , 32 ),
500- ("float8_e5m2" , "float8_e8m0fnu" , "float8_e5m2" , 32 ),
495+ # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
496+ # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
497+ # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
498+ # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
499+ # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
500+ # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
501501 ],
502502)
503503@pytest .mark .parametrize ("a_major" , ["k" ])
0 commit comments