     Sm100BlockScaledContiguousGroupedGemmKernel
 from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_finalize_fusion import \
     Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
+from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
+    Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
 from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
     Sm100BlockScaledPersistentDenseGemmKernel
 from ..cute_dsl_kernels.blackwell.utils import make_ptr
@@ -440,7 +442,7 @@ def generate_permuted_idx_to_expanded_idx(
 
     def inputs_pre_hook(self,
                         inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles = inputs
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
         num_tokens = self.infer_num_tokens(a.size(0))
         num_tokens_per_expert = self.generate_num_tokens_per_expert(
             num_tokens)
@@ -461,7 +463,7 @@ def inputs_pre_hook(self,
             [num_non_exiting_tiles_val],
             dtype=num_non_exiting_tiles.dtype,
             device=num_non_exiting_tiles.device)
-        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles
+        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others
 
     def inputs_pre_hook_finalize_fusion(
             self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -623,7 +625,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert tile_idx_to_group_idx.dtype == torch.int32
         assert tile_idx_to_group_idx.size() == (num_tiles, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1
 
         c = torch.empty(m, n, dtype=self.output_dtype, device=a.device)
 
@@ -900,7 +902,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert permuted_idx_to_expanded_idx.dtype == torch.int32
         assert permuted_idx_to_expanded_idx.size() == (m, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1
         assert token_final_scales.dtype == torch.float32
         assert token_final_scales.dim() == 2
         num_tokens = token_final_scales.size(0)
@@ -1091,6 +1093,304 @@ def _(
                         dtype=output_dtype,
                         device=input.device)
 
+class Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        TunableRunner):
+    kernel_class = Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
+    kernel_cache = dict()
+    tuning_config_cache = dict()
+
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 num_local_experts: int,
+                 local_expert_offset: int,
+                 tile_size: int,
+                 scaling_vector_size: int = 16):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.num_local_experts = num_local_experts
+        self.local_expert_offset = local_expert_offset
+        self.tile_size = tile_size
+        self.scaling_vector_size = scaling_vector_size
+
+        if get_sm_version() != 100:
+            raise ValueError(
+                f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
+            )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[Tuple[int, int]]:
+        a, b, *_ = inputs
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+
+        # TODO: Add full shmoo
+        mma_tiler_mn_candidates = [(128, 128), (128, 256)]
+        cluster_shape_mn_candidates = [(1, 1), (1, 2)]
+
+        valid_tactics = []
+        for mma_tiler_mn, cluster_shape_mn in itertools.product(
+                mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+            if self.__class__.kernel_class.can_implement(
+                    ab_dtype=cutlass.Float4E2M1FN,
+                    sf_dtype=cutlass.Float8E4M3FN,
+                    sf_vec_size=self.scaling_vector_size,
+                    acc_dtype=cutlass.Float32,
+                    c_dtype=cutlass.Float4E2M1FN,
+                    use_2cta_instrs=False,
+                    mma_tiler_mn=mma_tiler_mn,
+                    cluster_shape_mn=cluster_shape_mn,
+                    m=m,
+                    n=n,
+                    k=k,
+                    l=l,
+                    a_major="k",
+                    b_major="k",
+                    c_major="n",
+                    m_aligned=self.tile_size,
+            ):
+                valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+
+        return valid_tactics
+
+    def get_tuning_config(self) -> TuningConfig:
+        key = hash(self)
+        if key not in self.__class__.tuning_config_cache:
+            helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
+                                             self.num_local_experts,
+                                             self.local_expert_offset,
+                                             self.tile_size)
+            self.__class__.tuning_config_cache[key] = TuningConfig(
+                dynamic_tensor_specs=(DynamicTensorSpec(
+                    0, 0, helper.gen_tuning_buckets,
+                    helper.map_to_tuning_buckets), ),
+                constraint_specs=(ConstraintSpec(2, 0,
+                                                 fp4_scale_infer_shape),
+                                  ConstraintSpec(
+                                      5, 0,
+                                      helper.infer_shape_max_num_tiles)),
+                inputs_pre_hook=helper.inputs_pre_hook,
+            )
+        return self.__class__.tuning_config_cache[key]
+
+    def forward(self, inputs: List[torch.Tensor],
+                tactic: Optional[tuple]) -> torch.Tensor:
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, global_sf = inputs
+        assert a.dtype == torch.float4_e2m1fn_x2
+        assert a.dim() == 2
+        assert b.dtype == torch.float4_e2m1fn_x2
+        assert b.dim() == 3
+        assert a_sf.dtype == torch.uint8
+        assert a_sf.dim() == 1
+        assert b_sf.dtype == torch.uint8
+        assert b_sf.dim() == 3
+        assert alpha.dtype == torch.float32
+        assert alpha.dim() == 1
+
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+        scale_k = k // self.scaling_vector_size
+        interm_size = n // 2
+        assert m % self.tile_size == 0
+        assert k % (self.scaling_vector_size * 4) == 0
+        assert n % (self.scaling_vector_size * 4 * 2) == 0
+        assert b.size(2) * 2 == k
+        assert a_sf.size(0) == m * scale_k
+        assert b_sf.size(0) == l
+        assert b_sf.size(1) == n
+        assert b_sf.size(2) == scale_k
+        assert alpha.size(0) == l
+
+        num_tiles = m // self.tile_size
+        assert tile_idx_to_group_idx.dtype == torch.int32
+        assert tile_idx_to_group_idx.size() == (num_tiles, )
+        assert num_non_exiting_tiles.dtype == torch.int32
+        assert num_non_exiting_tiles.numel() == 1
+        assert global_sf.dtype == torch.float32
+        assert global_sf.numel() == 1
+
+        c = torch.empty(m, interm_size // 2, dtype=a.dtype, device=a.device)
+        c_sf = torch.empty(m * interm_size // self.scaling_vector_size,
+                           dtype=a_sf.dtype,
+                           device=a_sf.device)
+
+        a_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         a.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        b_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         b.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        a_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            a_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        b_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            b_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(),
+                             cute.AddressSpace.gmem)
+        tile_idx_to_group_idx_ptr = make_ptr(
+            cutlass.Int32, tile_idx_to_group_idx.data_ptr(),
+            cute.AddressSpace.gmem)
+        num_non_exiting_tiles_ptr = make_ptr(
+            cutlass.Int32, num_non_exiting_tiles.data_ptr(),
+            cute.AddressSpace.gmem)
+        global_sf_ptr = make_ptr(cutlass.Float32, global_sf.data_ptr(),
+                                 cute.AddressSpace.gmem)
+        c_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         c.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        c_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            c_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+
+        torch_stream = torch.cuda.current_stream()
+        stream = cuda.CUstream(torch_stream.cuda_stream)
+
+        if isinstance(tactic, tuple):
+            mma_tiler_mn, cluster_shape_mn = tactic
+        else:
+            mma_tiler_mn, cluster_shape_mn = (128, 128), (1, 1)
+
+        cache_key = (self.scaling_vector_size, self.tile_size, mma_tiler_mn,
+                     cluster_shape_mn)
+        if cache_key not in self.__class__.kernel_cache:
+            gemm = self.__class__.kernel_class(
+                sf_vec_size=self.scaling_vector_size,
+                acc_dtype=cutlass.Float32,
+                use_2cta_instrs=False,
+                mma_tiler_mn=mma_tiler_mn,
+                cluster_shape_mn=cluster_shape_mn,
+                vectorized_f32=True,
+            )
+            # Compute max active clusters on current device
+            hardware_info = cutlass.utils.HardwareInfo()
+            max_active_clusters = hardware_info.get_max_active_clusters(
+                cluster_shape_mn[0] * cluster_shape_mn[1])
+
+            compiled_gemm = cute.compile(
+                gemm.wrapper,
+                a_ptr,
+                b_ptr,
+                a_sf_ptr,
+                b_sf_ptr,
+                c_ptr,
+                c_sf_ptr,
+                alpha_ptr,
+                tile_idx_to_group_idx_ptr,
+                num_non_exiting_tiles_ptr,
+                global_sf_ptr,
+                m,
+                n,
+                k,
+                l,
+                tile_size=self.tile_size,
+                scaling_vector_size=self.scaling_vector_size,
+                max_active_clusters=max_active_clusters,
+                stream=stream,
+            )
+            self.__class__.kernel_cache[cache_key] = compiled_gemm
+        else:
+            compiled_gemm = self.__class__.kernel_cache[cache_key]
+
+        compiled_gemm(
+            a_ptr,
+            b_ptr,
+            a_sf_ptr,
+            b_sf_ptr,
+            c_ptr,
+            c_sf_ptr,
+            alpha_ptr,
+            tile_idx_to_group_idx_ptr,
+            num_non_exiting_tiles_ptr,
+            global_sf_ptr,
+            m,
+            n,
+            k,
+            l,
+            stream=stream,
+        )
+        return c, c_sf
+
+@torch.library.custom_op(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+    mutates_args=(),
+    device_types="cuda")
+def cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    tuner = AutoTuner.get()
+
+    runner = Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        num_experts, top_k, num_local_experts, local_expert_offset,
+        tile_size, scaling_vector_size)
+    inputs = [
+        input, weight, input_scale, weight_scale, alpha,
+        tile_idx_to_group_idx, num_non_exiting_tiles, global_sf
+    ]
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+        [runner],
+        runner.get_tuning_config(),
+        inputs,
+    )
+    output = runner(inputs, tactic=best_tactic)
+    return output
+
+@torch.library.register_fake(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell")
+def _(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    m = input.size(0)
+    n = weight.size(1)
+    interm_size = n // 2
+    output = torch.empty(m,
+                         interm_size // 2,
+                         dtype=input.dtype,
+                         device=input.device)
+    output_scale = torch.empty(m * interm_size // scaling_vector_size,
+                               dtype=input_scale.dtype,
+                               device=input_scale.device)
+    return output, output_scale
+
 class FusedMoEInputsHelper:
 
     def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
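
For illustration, a minimal shape-level sketch of how the new op could be invoked once this change is in place. The sizes below are hypothetical (not from the PR); it assumes a Blackwell (SM100) GPU, a PyTorch build that exposes torch.float4_e2m1fn_x2, and that importing the module above has registered the trtllm custom ops.

import torch

# Hypothetical MoE sizes; n is twice the intermediate size of the SwiGLU MLP.
num_experts, top_k = 64, 8
num_local_experts, local_expert_offset = 8, 0
tile_size, scaling_vector_size = 128, 16
m, n, k, l = 1024, 4096, 7168, num_local_experts
num_tiles = m // tile_size
device = "cuda"

# Packed NVFP4 operands (two FP4 values per byte) and UE4M3 scale factors stored as uint8.
a = torch.zeros(m, k // 2, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
b = torch.zeros(l, n, k // 2, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
a_sf = torch.zeros(m * k // scaling_vector_size, dtype=torch.uint8, device=device)
b_sf = torch.zeros(l, n, k // scaling_vector_size, dtype=torch.uint8, device=device)
alpha = torch.ones(l, dtype=torch.float32, device=device)
tile_idx_to_group_idx = torch.zeros(num_tiles, dtype=torch.int32, device=device)
num_non_exiting_tiles = torch.tensor([num_tiles], dtype=torch.int32, device=device)
global_sf = torch.ones(1, dtype=torch.float32, device=device)

c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
    a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles,
    global_sf, num_experts, top_k, num_local_experts, local_expert_offset,
    tile_size, scaling_vector_size)
# c: (m, n // 4) packed FP4 SwiGLU output; c_sf: flat UE4M3 scales of length m * (n // 2) // 16.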