
Commit 9ce5c52

swiglu fusion
Signed-off-by: Enwei Zhu <[email protected]>
1 parent a2c9f1c commit 9ce5c52
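
For context (not part of the commit message): the fusion applies a SwiGLU activation in the epilogue of the block-scaled grouped GEMM and requantizes the result to NVFP4. Below is a reference-only sketch of the assumed activation semantics, where the GEMM's n dimension packs the gate and up projections; the exact interleaving and the NVFP4 quantization live in the CUTE DSL kernel added by this commit.

# Reference-only sketch, not from this commit; assumes [gate, up] are the two
# chunked halves of the last dimension of the GEMM output.
import torch
import torch.nn.functional as F

def swiglu_reference(h: torch.Tensor) -> torch.Tensor:
    gate, up = h.chunk(2, dim=-1)   # h: [tokens, 2 * interm_size]
    return F.silu(gate) * up        # [tokens, interm_size]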

9 files changed: +3205 −82 lines changed


tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 304 additions & 4 deletions
@@ -26,6 +26,8 @@
     Sm100BlockScaledContiguousGroupedGemmKernel
 from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_finalize_fusion import \
     Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
+from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
+    Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
 from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
     Sm100BlockScaledPersistentDenseGemmKernel
 from ..cute_dsl_kernels.blackwell.utils import make_ptr
@@ -440,7 +442,7 @@ def generate_permuted_idx_to_expanded_idx(
 
     def inputs_pre_hook(self,
                         inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles = inputs
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
         num_tokens = self.infer_num_tokens(a.size(0))
         num_tokens_per_expert = self.generate_num_tokens_per_expert(
             num_tokens)
@@ -461,7 +463,7 @@ def inputs_pre_hook(self,
             [num_non_exiting_tiles_val],
             dtype=num_non_exiting_tiles.dtype,
             device=num_non_exiting_tiles.device)
-        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles
+        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others
 
     def inputs_pre_hook_finalize_fusion(
             self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -623,7 +625,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert tile_idx_to_group_idx.dtype == torch.int32
         assert tile_idx_to_group_idx.size() == (num_tiles, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1
 
         c = torch.empty(m, n, dtype=self.output_dtype, device=a.device)
 
@@ -900,7 +902,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert permuted_idx_to_expanded_idx.dtype == torch.int32
         assert permuted_idx_to_expanded_idx.size() == (m, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1
         assert token_final_scales.dtype == torch.float32
         assert token_final_scales.dim() == 2
         num_tokens = token_final_scales.size(0)
@@ -1091,6 +1093,304 @@ def _(
                          dtype=output_dtype,
                          device=input.device)
 
+class Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        TunableRunner):
+    kernel_class = Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
+    kernel_cache = dict()
+    tuning_config_cache = dict()
+
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 num_local_experts: int,
+                 local_expert_offset: int,
+                 tile_size: int,
+                 scaling_vector_size: int = 16):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.num_local_experts = num_local_experts
+        self.local_expert_offset = local_expert_offset
+        self.tile_size = tile_size
+        self.scaling_vector_size = scaling_vector_size
+
+        if get_sm_version() != 100:
+            raise ValueError(
+                f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
+            )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[Tuple[int, int]]:
+        a, b, *_ = inputs
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+
+        # TODO: Add full shmoo
+        mma_tiler_mn_candidates = [(128, 128), (128, 256)]
+        cluster_shape_mn_candidates = [(1, 1), (1, 2)]
+
+        valid_tactics = []
+        for mma_tiler_mn, cluster_shape_mn in itertools.product(
+                mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+            if self.__class__.kernel_class.can_implement(
+                    ab_dtype=cutlass.Float4E2M1FN,
+                    sf_dtype=cutlass.Float8E4M3FN,
+                    sf_vec_size=self.scaling_vector_size,
+                    acc_dtype=cutlass.Float32,
+                    c_dtype=cutlass.Float4E2M1FN,
+                    use_2cta_instrs=False,
+                    mma_tiler_mn=mma_tiler_mn,
+                    cluster_shape_mn=cluster_shape_mn,
+                    m=m,
+                    n=n,
+                    k=k,
+                    l=l,
+                    a_major="k",
+                    b_major="k",
+                    c_major="n",
+                    m_aligned=self.tile_size,
+            ):
+                valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+
+        return valid_tactics
+
+    def get_tuning_config(self) -> TuningConfig:
+        key = hash(self)
+        if key not in self.__class__.tuning_config_cache:
+            helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
+                                             self.num_local_experts,
+                                             self.local_expert_offset,
+                                             self.tile_size)
+            self.__class__.tuning_config_cache[key] = TuningConfig(
+                dynamic_tensor_specs=(DynamicTensorSpec(
+                    0, 0, helper.gen_tuning_buckets,
+                    helper.map_to_tuning_buckets), ),
+                constraint_specs=(ConstraintSpec(2, 0,
+                                                 fp4_scale_infer_shape),
+                                  ConstraintSpec(
+                                      5, 0,
+                                      helper.infer_shape_max_num_tiles)),
+                inputs_pre_hook=helper.inputs_pre_hook,
+            )
+        return self.__class__.tuning_config_cache[key]
+
+    def forward(self, inputs: List[torch.Tensor],
+                tactic: Optional[tuple]) -> torch.Tensor:
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, global_sf = inputs
+        assert a.dtype == torch.float4_e2m1fn_x2
+        assert a.dim() == 2
+        assert b.dtype == torch.float4_e2m1fn_x2
+        assert b.dim() == 3
+        assert a_sf.dtype == torch.uint8
+        assert a_sf.dim() == 1
+        assert b_sf.dtype == torch.uint8
+        assert b_sf.dim() == 3
+        assert alpha.dtype == torch.float32
+        assert alpha.dim() == 1
+
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+        scale_k = k // self.scaling_vector_size
+        interm_size = n // 2
+        assert m % self.tile_size == 0
+        assert k % (self.scaling_vector_size * 4) == 0
+        assert n % (self.scaling_vector_size * 4 * 2) == 0
+        assert b.size(2) * 2 == k
+        assert a_sf.size(0) == m * scale_k
+        assert b_sf.size(0) == l
+        assert b_sf.size(1) == n
+        assert b_sf.size(2) == scale_k
+        assert alpha.size(0) == l
+
+        num_tiles = m // self.tile_size
+        assert tile_idx_to_group_idx.dtype == torch.int32
+        assert tile_idx_to_group_idx.size() == (num_tiles, )
+        assert num_non_exiting_tiles.dtype == torch.int32
+        assert num_non_exiting_tiles.numel() == 1
+        assert global_sf.dtype == torch.float32
+        assert global_sf.numel() == 1
+
+        c = torch.empty(m, interm_size // 2, dtype=a.dtype, device=a.device)
+        c_sf = torch.empty(m * interm_size // self.scaling_vector_size,
+                           dtype=a_sf.dtype,
+                           device=a_sf.device)
+
+        a_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         a.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        b_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         b.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        a_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            a_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        b_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            b_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(),
+                             cute.AddressSpace.gmem)
+        tile_idx_to_group_idx_ptr = make_ptr(
+            cutlass.Int32, tile_idx_to_group_idx.data_ptr(),
+            cute.AddressSpace.gmem)
+        num_non_exiting_tiles_ptr = make_ptr(
+            cutlass.Int32, num_non_exiting_tiles.data_ptr(),
+            cute.AddressSpace.gmem)
+        global_sf_ptr = make_ptr(cutlass.Float32, global_sf.data_ptr(),
+                                 cute.AddressSpace.gmem)
+        c_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         c.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        c_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            c_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+
+        torch_stream = torch.cuda.current_stream()
+        stream = cuda.CUstream(torch_stream.cuda_stream)
+
+        if isinstance(tactic, tuple):
+            mma_tiler_mn, cluster_shape_mn = tactic
+        else:
+            mma_tiler_mn, cluster_shape_mn = (128, 128), (1, 1)
+
+        cache_key = (self.scaling_vector_size, self.tile_size, mma_tiler_mn,
+                     cluster_shape_mn)
+        if cache_key not in self.__class__.kernel_cache:
+            gemm = self.__class__.kernel_class(
+                sf_vec_size=self.scaling_vector_size,
+                acc_dtype=cutlass.Float32,
+                use_2cta_instrs=False,
+                mma_tiler_mn=mma_tiler_mn,
+                cluster_shape_mn=cluster_shape_mn,
+                vectorized_f32=True,
+            )
+            # Compute max active clusters on current device
+            hardware_info = cutlass.utils.HardwareInfo()
+            max_active_clusters = hardware_info.get_max_active_clusters(
+                cluster_shape_mn[0] * cluster_shape_mn[1])
+
+            compiled_gemm = cute.compile(
+                gemm.wrapper,
+                a_ptr,
+                b_ptr,
+                a_sf_ptr,
+                b_sf_ptr,
+                c_ptr,
+                c_sf_ptr,
+                alpha_ptr,
+                tile_idx_to_group_idx_ptr,
+                num_non_exiting_tiles_ptr,
+                global_sf_ptr,
+                m,
+                n,
+                k,
+                l,
+                tile_size=self.tile_size,
+                scaling_vector_size=self.scaling_vector_size,
+                max_active_clusters=max_active_clusters,
+                stream=stream,
+            )
+            self.__class__.kernel_cache[cache_key] = compiled_gemm
+        else:
+            compiled_gemm = self.__class__.kernel_cache[cache_key]
+
+        compiled_gemm(
+            a_ptr,
+            b_ptr,
+            a_sf_ptr,
+            b_sf_ptr,
+            c_ptr,
+            c_sf_ptr,
+            alpha_ptr,
+            tile_idx_to_group_idx_ptr,
+            num_non_exiting_tiles_ptr,
+            global_sf_ptr,
+            m,
+            n,
+            k,
+            l,
+            stream=stream,
+        )
+        return c, c_sf
+
+@torch.library.custom_op(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+    mutates_args=(),
+    device_types="cuda")
+def cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    tuner = AutoTuner.get()
+
+    runner = Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        num_experts, top_k, num_local_experts, local_expert_offset,
+        tile_size, scaling_vector_size)
+    inputs = [
+        input, weight, input_scale, weight_scale, alpha,
+        tile_idx_to_group_idx, num_non_exiting_tiles, global_sf
+    ]
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+        [runner],
+        runner.get_tuning_config(),
+        inputs,
+    )
+    output = runner(inputs, tactic=best_tactic)
+    return output
+
+@torch.library.register_fake(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell")
+def _(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    m = input.size(0)
+    n = weight.size(1)
+    interm_size = n // 2
+    output = torch.empty(m,
+                         interm_size // 2,
+                         dtype=input.dtype,
+                         device=input.device)
+    output_scale = torch.empty(m * interm_size // scaling_vector_size,
+                               dtype=input_scale.dtype,
+                               device=input_scale.device)
+    return output, output_scale
+
 class FusedMoEInputsHelper:
 
     def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
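
Not part of the commit: a minimal call sketch for the new op, with dummy data whose shapes satisfy the assertions in the runner above (m a multiple of tile_size, k a multiple of 64, n a multiple of 128, n packing the gate and up projections). It assumes a Blackwell (SM 100) GPU, a PyTorch build that provides torch.float4_e2m1fn_x2, and that importing the module above registers the op under torch.ops.trtllm; the tensors are uninitialized, so only the shape/dtype contract is exercised, not numerical correctness.

# Hypothetical usage sketch (not from the commit); see assumptions above.
import torch
import tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops  # noqa: F401, registers the op

device = "cuda"
m, k, n, l = 512, 512, 1024, 8      # n packs [gate, up], so interm_size = n // 2 = 512
tile_size, sv = 128, 16             # sv: scaling_vector_size (NVFP4 block size)
num_tiles = m // tile_size

a = torch.empty(m, k // 2, dtype=torch.float4_e2m1fn_x2, device=device)
b = torch.empty(l, n, k // 2, dtype=torch.float4_e2m1fn_x2, device=device)
a_sf = torch.empty(m * (k // sv), dtype=torch.uint8, device=device)
b_sf = torch.empty(l, n, k // sv, dtype=torch.uint8, device=device)
alpha = torch.ones(l, dtype=torch.float32, device=device)
tile_idx_to_group_idx = torch.zeros(num_tiles, dtype=torch.int32, device=device)
num_non_exiting_tiles = torch.tensor([num_tiles], dtype=torch.int32, device=device)
global_sf = torch.ones(1, dtype=torch.float32, device=device)

c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
    a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles,
    global_sf, num_experts=l, top_k=2, num_local_experts=l,
    local_expert_offset=0, tile_size=tile_size, scaling_vector_size=sv)
print(c.shape, c_sf.shape)  # expected: (512, 256) packed NVFP4, (16384,) scale bytes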

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py

Lines changed: 2 additions & 3 deletions
@@ -52,6 +52,8 @@
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass.cute.nvgpu import cpasync, tcgen05
 
+from .utils import is_power_of_2
+
 
 class Sm100BlockScaledContiguousGroupedGemmKernel:
     """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
@@ -2052,9 +2054,6 @@ def is_valid_mma_tiler_and_cluster_shape(
             is_valid = False
 
         # Skip invalid cluster shape
-        def is_power_of_2(x: int) -> bool:
-            return x > 0 and (x & (x - 1)) == 0
-
         if (
             cluster_shape_mn[0] * cluster_shape_mn[1] > 16
             or cluster_shape_mn[0] <= 0

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Lines changed: 2 additions & 3 deletions
@@ -39,6 +39,8 @@
 from cutlass.cute.nvgpu import cpasync, tcgen05
 from cutlass.cutlass_dsl import T, dsl_user_op
 
+from .utils import is_power_of_2
+
 """
 High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for
 the NVIDIA Blackwell architecture using CUTE DSL.
@@ -2060,9 +2062,6 @@ def is_valid_mma_tiler_and_cluster_shape(
             is_valid = False
 
         # Skip invalid cluster shape
-        def is_power_of_2(x: int) -> bool:
-            return x > 0 and (x & (x - 1)) == 0
-
         if (
             cluster_shape_mn[0] * cluster_shape_mn[1] > 16
             or cluster_shape_mn[0] <= 0
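
The diff for blackwell/utils.py is not among the files shown on this page. Based on the two identical inline definitions removed above and the new "from .utils import is_power_of_2" imports, the shared helper presumably looks like the following sketch (the docstring is mine):

def is_power_of_2(x: int) -> bool:
    """Return True if x is a positive power of two."""
    return x > 0 and (x & (x - 1)) == 0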
