     Sm100BlockScaledContiguousGroupedGemmKernel
 from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_finalize_fusion import \
     Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
+from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
+    Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
 from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
     Sm100BlockScaledPersistentDenseGemmKernel
 from ..cute_dsl_kernels.blackwell.utils import make_ptr
@@ -439,7 +441,7 @@ def generate_permuted_idx_to_expanded_idx(

     def inputs_pre_hook(self,
                         inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles = inputs
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
         num_tokens = self.infer_num_tokens(a.size(0))
         num_tokens_per_expert = self.generate_num_tokens_per_expert(
             num_tokens)
@@ -460,7 +462,7 @@ def inputs_pre_hook(self,
             [num_non_exiting_tiles_val],
             dtype=num_non_exiting_tiles.dtype,
             device=num_non_exiting_tiles.device)
-        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles
+        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others

     def inputs_pre_hook_finalize_fusion(
             self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -622,7 +624,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert tile_idx_to_group_idx.dtype == torch.int32
         assert tile_idx_to_group_idx.size() == (num_tiles, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1

         c = torch.empty(m, n, dtype=self.output_dtype, device=a.device)

@@ -899,7 +901,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert permuted_idx_to_expanded_idx.dtype == torch.int32
         assert permuted_idx_to_expanded_idx.size() == (m, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1
         assert token_final_scales.dtype == torch.float32
         assert token_final_scales.dim() == 2
         num_tokens = token_final_scales.size(0)
@@ -1090,6 +1092,304 @@ def _(
                         dtype=output_dtype,
                         device=input.device)

+class Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        TunableRunner):
+    kernel_class = Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
+    kernel_cache = dict()
+    tuning_config_cache = dict()
+
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 num_local_experts: int,
+                 local_expert_offset: int,
+                 tile_size: int,
+                 scaling_vector_size: int = 16):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.num_local_experts = num_local_experts
+        self.local_expert_offset = local_expert_offset
+        self.tile_size = tile_size
+        self.scaling_vector_size = scaling_vector_size
+
+        if get_sm_version() != 100:
+            raise ValueError(
+                f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
+            )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[Tuple[int, int]]:
+        a, b, *_ = inputs
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+
+        # TODO: Add full shmoo
+        mma_tiler_mn_candidates = [(128, 128), (128, 256)]
+        cluster_shape_mn_candidates = [(1, 1), (1, 2)]
+
+        valid_tactics = []
+        for mma_tiler_mn, cluster_shape_mn in itertools.product(
+                mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+            if self.__class__.kernel_class.can_implement(
+                    ab_dtype=cutlass.Float4E2M1FN,
+                    sf_dtype=cutlass.Float8E4M3FN,
+                    sf_vec_size=self.scaling_vector_size,
+                    acc_dtype=cutlass.Float32,
+                    c_dtype=cutlass.Float4E2M1FN,
+                    use_2cta_instrs=False,
+                    mma_tiler_mn=mma_tiler_mn,
+                    cluster_shape_mn=cluster_shape_mn,
+                    m=m,
+                    n=n,
+                    k=k,
+                    l=l,
+                    a_major="k",
+                    b_major="k",
+                    c_major="n",
+                    m_aligned=self.tile_size,
+            ):
+                valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+
+        return valid_tactics
+
+    def get_tuning_config(self) -> TuningConfig:
+        key = hash(self)
+        if key not in self.__class__.tuning_config_cache:
+            helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
+                                             self.num_local_experts,
+                                             self.local_expert_offset,
+                                             self.tile_size)
+            self.__class__.tuning_config_cache[key] = TuningConfig(
+                dynamic_tensor_specs=(DynamicTensorSpec(
+                    0, 0, helper.gen_tuning_buckets,
+                    helper.map_to_tuning_buckets), ),
+                constraint_specs=(ConstraintSpec(2, 0,
+                                                 fp4_scale_infer_shape),
+                                  ConstraintSpec(
+                                      5, 0,
+                                      helper.infer_shape_max_num_tiles)),
+                inputs_pre_hook=helper.inputs_pre_hook,
+            )
+        return self.__class__.tuning_config_cache[key]
+
+    def forward(self, inputs: List[torch.Tensor],
+                tactic: Optional[tuple]) -> Tuple[torch.Tensor, torch.Tensor]:
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, global_sf = inputs
+        assert a.dtype == torch.float4_e2m1fn_x2
+        assert a.dim() == 2
+        assert b.dtype == torch.float4_e2m1fn_x2
+        assert b.dim() == 3
+        assert a_sf.dtype == torch.uint8
+        assert a_sf.dim() == 1
+        assert b_sf.dtype == torch.uint8
+        assert b_sf.dim() == 3
+        assert alpha.dtype == torch.float32
+        assert alpha.dim() == 1
+
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+        scale_k = k // self.scaling_vector_size
+        interm_size = n // 2
+        assert m % self.tile_size == 0
+        assert k % (self.scaling_vector_size * 4) == 0
+        assert n % (self.scaling_vector_size * 4 * 2) == 0
+        assert b.size(2) * 2 == k
+        assert a_sf.size(0) == m * scale_k
+        assert b_sf.size(0) == l
+        assert b_sf.size(1) == n
+        assert b_sf.size(2) == scale_k
+        assert alpha.size(0) == l
+
+        num_tiles = m // self.tile_size
+        assert tile_idx_to_group_idx.dtype == torch.int32
+        assert tile_idx_to_group_idx.size() == (num_tiles, )
+        assert num_non_exiting_tiles.dtype == torch.int32
+        assert num_non_exiting_tiles.numel() == 1
+        assert global_sf.dtype == torch.float32
+        assert global_sf.numel() == 1
+
+        c = torch.empty(m, interm_size // 2, dtype=a.dtype, device=a.device)
+        c_sf = torch.empty(m * interm_size // self.scaling_vector_size,
+                           dtype=a_sf.dtype,
+                           device=a_sf.device)
+
+        a_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         a.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        b_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         b.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        a_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            a_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        b_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            b_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(),
+                             cute.AddressSpace.gmem)
+        tile_idx_to_group_idx_ptr = make_ptr(
+            cutlass.Int32, tile_idx_to_group_idx.data_ptr(),
+            cute.AddressSpace.gmem)
+        num_non_exiting_tiles_ptr = make_ptr(
+            cutlass.Int32, num_non_exiting_tiles.data_ptr(),
+            cute.AddressSpace.gmem)
+        global_sf_ptr = make_ptr(cutlass.Float32, global_sf.data_ptr(),
+                                 cute.AddressSpace.gmem)
+        c_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         c.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        c_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            c_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+
+        torch_stream = torch.cuda.current_stream()
+        stream = cuda.CUstream(torch_stream.cuda_stream)
+
+        if isinstance(tactic, tuple):
+            mma_tiler_mn, cluster_shape_mn = tactic
+        else:
+            mma_tiler_mn, cluster_shape_mn = (128, 128), (1, 1)
+
+        cache_key = (self.scaling_vector_size, self.tile_size, mma_tiler_mn,
+                     cluster_shape_mn)
+        if cache_key not in self.__class__.kernel_cache:
+            gemm = self.__class__.kernel_class(
+                sf_vec_size=self.scaling_vector_size,
+                acc_dtype=cutlass.Float32,
+                use_2cta_instrs=False,
+                mma_tiler_mn=mma_tiler_mn,
+                cluster_shape_mn=cluster_shape_mn,
+                vectorized_f32=True,
+            )
+            # Compute max active clusters on current device
+            hardware_info = cutlass.utils.HardwareInfo()
+            max_active_clusters = hardware_info.get_max_active_clusters(
+                cluster_shape_mn[0] * cluster_shape_mn[1])
+
+            compiled_gemm = cute.compile(
+                gemm.wrapper,
+                a_ptr,
+                b_ptr,
+                a_sf_ptr,
+                b_sf_ptr,
+                c_ptr,
+                c_sf_ptr,
+                alpha_ptr,
+                tile_idx_to_group_idx_ptr,
+                num_non_exiting_tiles_ptr,
+                global_sf_ptr,
+                m,
+                n,
+                k,
+                l,
+                tile_size=self.tile_size,
+                scaling_vector_size=self.scaling_vector_size,
+                max_active_clusters=max_active_clusters,
+                stream=stream,
+            )
+            self.__class__.kernel_cache[cache_key] = compiled_gemm
+        else:
+            compiled_gemm = self.__class__.kernel_cache[cache_key]
+
+        compiled_gemm(
+            a_ptr,
+            b_ptr,
+            a_sf_ptr,
+            b_sf_ptr,
+            c_ptr,
+            c_sf_ptr,
+            alpha_ptr,
+            tile_idx_to_group_idx_ptr,
+            num_non_exiting_tiles_ptr,
+            global_sf_ptr,
+            m,
+            n,
+            k,
+            l,
+            stream=stream,
+        )
+        return c, c_sf
+
+@torch.library.custom_op(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+    mutates_args=(),
+    device_types="cuda")
+def cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    tuner = AutoTuner.get()
+
+    runner = Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        num_experts, top_k, num_local_experts, local_expert_offset,
+        tile_size, scaling_vector_size)
+    inputs = [
+        input, weight, input_scale, weight_scale, alpha,
+        tile_idx_to_group_idx, num_non_exiting_tiles, global_sf
+    ]
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+        [runner],
+        runner.get_tuning_config(),
+        inputs,
+    )
+    output = runner(inputs, tactic=best_tactic)
+    return output
+
+@torch.library.register_fake(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell")
+def _(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    m = input.size(0)
+    n = weight.size(1)
+    interm_size = n // 2
+    output = torch.empty(m,
+                         interm_size // 2,
+                         dtype=input.dtype,
+                         device=input.device)
+    output_scale = torch.empty(m * interm_size // scaling_vector_size,
+                               dtype=input_scale.dtype,
+                               device=input_scale.device)
+    return output, output_scale
+
 class FusedMoEInputsHelper:

     def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
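
Below is a minimal usage sketch, separate from the patch above, of how the newly registered op could be invoked through torch.ops once this file is imported. The shapes mirror the asserts in Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner.forward; the concrete sizes, expert counts, routing tensors, and tensor contents are placeholder assumptions rather than a working MoE setup.

import torch

# Placeholder problem sizes; m must be a multiple of tile_size, and k/n must
# satisfy the alignment asserts (k % 64 == 0 and n % 128 == 0 for sf_vec = 16).
m, k, n, l = 256, 512, 1024, 8
tile_size, sf_vec = 128, 16
scale_k = k // sf_vec
num_tiles = m // tile_size

a = torch.empty(m, k // 2, dtype=torch.float4_e2m1fn_x2, device="cuda")     # packed FP4 activations
b = torch.empty(l, n, k // 2, dtype=torch.float4_e2m1fn_x2, device="cuda")  # packed FP4 expert weights
a_sf = torch.empty(m * scale_k, dtype=torch.uint8, device="cuda")           # per-block activation scales
b_sf = torch.empty(l, n, scale_k, dtype=torch.uint8, device="cuda")         # per-block weight scales
alpha = torch.ones(l, dtype=torch.float32, device="cuda")
tile_idx_to_group_idx = torch.zeros(num_tiles, dtype=torch.int32, device="cuda")
num_non_exiting_tiles = torch.tensor([num_tiles], dtype=torch.int32, device="cuda")
global_sf = torch.ones(1, dtype=torch.float32, device="cuda")

out, out_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
    a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles,
    global_sf, num_experts=64, top_k=4, num_local_experts=l,
    local_expert_offset=0, tile_size=tile_size, scaling_vector_size=sf_vec)

# out is packed FP4 of shape (m, n // 4): SwiGLU halves n to n // 2 values and
# two FP4 values pack into each byte; out_sf holds the fused output scale factors.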