
Commit e47b434

swiglu fusion
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 52fb85e commit e47b434
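
For context: "swiglu fusion" means the SwiGLU activation that normally runs between the MoE FC1 grouped GEMM and the FC2 GEMM is applied inside the GEMM epilogue, and the activated output is re-quantized to NVFP4 (packed FP4 values plus FP8 block scale factors) in the same kernel. A minimal unfused reference of the activation is sketched below; the gate/linear split order and the names are assumptions for illustration only.

```python
# Hedged sketch: unfused reference of the epilogue this commit fuses into the
# grouped GEMM. The split order of the two projections is an assumption; the
# fused kernel additionally re-quantizes the result to NVFP4 (values + scales).
import torch
import torch.nn.functional as F


def swiglu_reference(gemm_out: torch.Tensor) -> torch.Tensor:
    # gemm_out: (m, n) result of the expert FC1 grouped GEMM.
    gate, up = gemm_out.chunk(2, dim=-1)  # assumed ordering of the two halves
    return F.silu(gate) * up              # (m, n // 2)
```

This matches the shapes in the new runner below: the kernel's logical N dimension is halved (interm_size = n // 2), the FP4 output packs two values per byte (interm_size // 2 columns), and one FP8 scale is produced per scaling_vector_size (default 16) elements.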

File tree

9 files changed: +3205 -82 lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 304 additions & 4 deletions
@@ -26,6 +26,8 @@
     Sm100BlockScaledContiguousGroupedGemmKernel
 from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_finalize_fusion import \
     Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
+from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
+    Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
 from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
     Sm100BlockScaledPersistentDenseGemmKernel
 from ..cute_dsl_kernels.blackwell.utils import make_ptr
@@ -439,7 +441,7 @@ def generate_permuted_idx_to_expanded_idx(
 
     def inputs_pre_hook(self,
                         inputs: List[torch.Tensor]) -> List[torch.Tensor]:
-        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles = inputs
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others = inputs
         num_tokens = self.infer_num_tokens(a.size(0))
         num_tokens_per_expert = self.generate_num_tokens_per_expert(
             num_tokens)
@@ -460,7 +462,7 @@ def inputs_pre_hook(self,
             [num_non_exiting_tiles_val],
             dtype=num_non_exiting_tiles.dtype,
             device=num_non_exiting_tiles.device)
-        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles
+        return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, *others
 
     def inputs_pre_hook_finalize_fusion(
             self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -622,7 +624,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert tile_idx_to_group_idx.dtype == torch.int32
         assert tile_idx_to_group_idx.size() == (num_tiles, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1
 
         c = torch.empty(m, n, dtype=self.output_dtype, device=a.device)
 
@@ -899,7 +901,7 @@ def forward(self, inputs: List[torch.Tensor],
         assert permuted_idx_to_expanded_idx.dtype == torch.int32
         assert permuted_idx_to_expanded_idx.size() == (m, )
         assert num_non_exiting_tiles.dtype == torch.int32
-        assert num_non_exiting_tiles.size() == (1, )
+        assert num_non_exiting_tiles.numel() == 1
         assert token_final_scales.dtype == torch.float32
         assert token_final_scales.dim() == 2
         num_tokens = token_final_scales.size(0)
@@ -1090,6 +1092,304 @@ def _(
                        dtype=output_dtype,
                        device=input.device)
 
+class Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        TunableRunner):
+    kernel_class = Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
+    kernel_cache = dict()
+    tuning_config_cache = dict()
+
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 num_local_experts: int,
+                 local_expert_offset: int,
+                 tile_size: int,
+                 scaling_vector_size: int = 16):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.num_local_experts = num_local_experts
+        self.local_expert_offset = local_expert_offset
+        self.tile_size = tile_size
+        self.scaling_vector_size = scaling_vector_size
+
+        if get_sm_version() != 100:
+            raise ValueError(
+                f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
+            )
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[Tuple[int, int]]:
+        a, b, *_ = inputs
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+
+        # TODO: Add full shmoo
+        mma_tiler_mn_candidates = [(128, 128), (128, 256)]
+        cluster_shape_mn_candidates = [(1, 1), (1, 2)]
+
+        valid_tactics = []
+        for mma_tiler_mn, cluster_shape_mn in itertools.product(
+                mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+            if self.__class__.kernel_class.can_implement(
+                    ab_dtype=cutlass.Float4E2M1FN,
+                    sf_dtype=cutlass.Float8E4M3FN,
+                    sf_vec_size=self.scaling_vector_size,
+                    acc_dtype=cutlass.Float32,
+                    c_dtype=cutlass.Float4E2M1FN,
+                    use_2cta_instrs=False,
+                    mma_tiler_mn=mma_tiler_mn,
+                    cluster_shape_mn=cluster_shape_mn,
+                    m=m,
+                    n=n,
+                    k=k,
+                    l=l,
+                    a_major="k",
+                    b_major="k",
+                    c_major="n",
+                    m_aligned=self.tile_size,
+            ):
+                valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+
+        return valid_tactics
+
+    def get_tuning_config(self) -> TuningConfig:
+        key = hash(self)
+        if key not in self.__class__.tuning_config_cache:
+            helper = GroupedGemmInputsHelper(self.num_experts, self.top_k,
+                                             self.num_local_experts,
+                                             self.local_expert_offset,
+                                             self.tile_size)
+            self.__class__.tuning_config_cache[key] = TuningConfig(
+                dynamic_tensor_specs=(DynamicTensorSpec(
+                    0, 0, helper.gen_tuning_buckets,
+                    helper.map_to_tuning_buckets), ),
+                constraint_specs=(ConstraintSpec(2, 0,
+                                                 fp4_scale_infer_shape),
+                                  ConstraintSpec(
+                                      5, 0,
+                                      helper.infer_shape_max_num_tiles)),
+                inputs_pre_hook=helper.inputs_pre_hook,
+            )
+        return self.__class__.tuning_config_cache[key]
+
+    def forward(self, inputs: List[torch.Tensor],
+                tactic: Optional[tuple]) -> torch.Tensor:
+        a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles, global_sf = inputs
+        assert a.dtype == torch.float4_e2m1fn_x2
+        assert a.dim() == 2
+        assert b.dtype == torch.float4_e2m1fn_x2
+        assert b.dim() == 3
+        assert a_sf.dtype == torch.uint8
+        assert a_sf.dim() == 1
+        assert b_sf.dtype == torch.uint8
+        assert b_sf.dim() == 3
+        assert alpha.dtype == torch.float32
+        assert alpha.dim() == 1
+
+        m, k = a.size(0), a.size(1) * 2
+        l, n = b.size(0), b.size(1)
+        scale_k = k // self.scaling_vector_size
+        interm_size = n // 2
+        assert m % self.tile_size == 0
+        assert k % (self.scaling_vector_size * 4) == 0
+        assert n % (self.scaling_vector_size * 4 * 2) == 0
+        assert b.size(2) * 2 == k
+        assert a_sf.size(0) == m * scale_k
+        assert b_sf.size(0) == l
+        assert b_sf.size(1) == n
+        assert b_sf.size(2) == scale_k
+        assert alpha.size(0) == l
+
+        num_tiles = m // self.tile_size
+        assert tile_idx_to_group_idx.dtype == torch.int32
+        assert tile_idx_to_group_idx.size() == (num_tiles, )
+        assert num_non_exiting_tiles.dtype == torch.int32
+        assert num_non_exiting_tiles.numel() == 1
+        assert global_sf.dtype == torch.float32
+        assert global_sf.numel() == 1
+
+        c = torch.empty(m, interm_size // 2, dtype=a.dtype, device=a.device)
+        c_sf = torch.empty(m * interm_size // self.scaling_vector_size,
+                           dtype=a_sf.dtype,
+                           device=a_sf.device)
+
+        a_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         a.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        b_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         b.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        a_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            a_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        b_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            b_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+        alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(),
+                             cute.AddressSpace.gmem)
+        tile_idx_to_group_idx_ptr = make_ptr(
+            cutlass.Int32, tile_idx_to_group_idx.data_ptr(),
+            cute.AddressSpace.gmem)
+        num_non_exiting_tiles_ptr = make_ptr(
+            cutlass.Int32, num_non_exiting_tiles.data_ptr(),
+            cute.AddressSpace.gmem)
+        global_sf_ptr = make_ptr(cutlass.Float32, global_sf.data_ptr(),
+                                 cute.AddressSpace.gmem)
+        c_ptr = make_ptr(cutlass.Float4E2M1FN,
+                         c.data_ptr(),
+                         cute.AddressSpace.gmem,
+                         assumed_align=32)
+        c_sf_ptr = make_ptr(cutlass.Float8E4M3FN,
+                            c_sf.data_ptr(),
+                            cute.AddressSpace.gmem,
+                            assumed_align=16)
+
+        torch_stream = torch.cuda.current_stream()
+        stream = cuda.CUstream(torch_stream.cuda_stream)
+
+        if isinstance(tactic, tuple):
+            mma_tiler_mn, cluster_shape_mn = tactic
+        else:
+            mma_tiler_mn, cluster_shape_mn = (128, 128), (1, 1)
+
+        cache_key = (self.scaling_vector_size, self.tile_size, mma_tiler_mn,
+                     cluster_shape_mn)
+        if cache_key not in self.__class__.kernel_cache:
+            gemm = self.__class__.kernel_class(
+                sf_vec_size=self.scaling_vector_size,
+                acc_dtype=cutlass.Float32,
+                use_2cta_instrs=False,
+                mma_tiler_mn=mma_tiler_mn,
+                cluster_shape_mn=cluster_shape_mn,
+                vectorized_f32=True,
+            )
+            # Compute max active clusters on current device
+            hardware_info = cutlass.utils.HardwareInfo()
+            max_active_clusters = hardware_info.get_max_active_clusters(
+                cluster_shape_mn[0] * cluster_shape_mn[1])
+
+            compiled_gemm = cute.compile(
+                gemm.wrapper,
+                a_ptr,
+                b_ptr,
+                a_sf_ptr,
+                b_sf_ptr,
+                c_ptr,
+                c_sf_ptr,
+                alpha_ptr,
+                tile_idx_to_group_idx_ptr,
+                num_non_exiting_tiles_ptr,
+                global_sf_ptr,
+                m,
+                n,
+                k,
+                l,
+                tile_size=self.tile_size,
+                scaling_vector_size=self.scaling_vector_size,
+                max_active_clusters=max_active_clusters,
+                stream=stream,
+            )
+            self.__class__.kernel_cache[cache_key] = compiled_gemm
+        else:
+            compiled_gemm = self.__class__.kernel_cache[cache_key]
+
+        compiled_gemm(
+            a_ptr,
+            b_ptr,
+            a_sf_ptr,
+            b_sf_ptr,
+            c_ptr,
+            c_sf_ptr,
+            alpha_ptr,
+            tile_idx_to_group_idx_ptr,
+            num_non_exiting_tiles_ptr,
+            global_sf_ptr,
+            m,
+            n,
+            k,
+            l,
+            stream=stream,
+        )
+        return c, c_sf
+
+@torch.library.custom_op(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+    mutates_args=(),
+    device_types="cuda")
+def cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    tuner = AutoTuner.get()
+
+    runner = Sm100BlockScaledContiguousGroupedGemmSwigluFusionRunner(
+        num_experts, top_k, num_local_experts, local_expert_offset,
+        tile_size, scaling_vector_size)
+    inputs = [
+        input, weight, input_scale, weight_scale, alpha,
+        tile_idx_to_group_idx, num_non_exiting_tiles, global_sf
+    ]
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell",
+        [runner],
+        runner.get_tuning_config(),
+        inputs,
+    )
+    output = runner(inputs, tactic=best_tactic)
+    return output
+
+@torch.library.register_fake(
+    "trtllm::cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell")
+def _(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    alpha: torch.Tensor,
+    tile_idx_to_group_idx: torch.Tensor,
+    num_non_exiting_tiles: torch.Tensor,
+    global_sf: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    num_local_experts: int,
+    local_expert_offset: int,
+    tile_size: int,
+    scaling_vector_size: int = 16,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    m = input.size(0)
+    n = weight.size(1)
+    interm_size = n // 2
+    output = torch.empty(m,
+                         interm_size // 2,
+                         dtype=input.dtype,
+                         device=input.device)
+    output_scale = torch.empty(m * interm_size // scaling_vector_size,
+                               dtype=input_scale.dtype,
+                               device=input_scale.device)
+    return output, output_scale
+
 class FusedMoEInputsHelper:
 
     def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
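
A hedged sketch of how the newly registered op might be called; the shapes follow the assertions in the runner's forward() above, but the concrete sizes, expert counts, and routing metadata are illustrative placeholders (in practice they come from the MoE routing/permutation step, and a PyTorch build exposing torch.float4_e2m1fn_x2 is required):

```python
# Hedged usage sketch: illustrative values only; requires an SM100 GPU and
# NVFP4-quantized inputs plus routing metadata from the surrounding MoE code.
import torch

m, k, interm = 256, 7168, 2048          # assumed sizes; m % tile_size == 0
n = interm * 2                          # gate + up projections share one GEMM
l, tile_size, sf_vec = 8, 128, 16       # 8 local experts, 16-wide block scales

a = torch.zeros(m, k // 2, dtype=torch.uint8, device="cuda").view(torch.float4_e2m1fn_x2)
b = torch.zeros(l, n, k // 2, dtype=torch.uint8, device="cuda").view(torch.float4_e2m1fn_x2)
a_sf = torch.zeros(m * (k // sf_vec), dtype=torch.uint8, device="cuda")
b_sf = torch.zeros(l, n, k // sf_vec, dtype=torch.uint8, device="cuda")
alpha = torch.ones(l, dtype=torch.float32, device="cuda")
tile_idx_to_group_idx = torch.zeros(m // tile_size, dtype=torch.int32, device="cuda")
num_non_exiting_tiles = torch.tensor([m // tile_size], dtype=torch.int32, device="cuda")
global_sf = torch.ones(1, dtype=torch.float32, device="cuda")

c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
    a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, num_non_exiting_tiles,
    global_sf, num_experts=64, top_k=8, num_local_experts=l,
    local_expert_offset=0, tile_size=tile_size, scaling_vector_size=sf_vec)
# c: (m, interm // 2) packed FP4 activations; c_sf: flat FP8 block scales (uint8)
```

Because the epilogue already quantizes, the returned pair is presumably consumed directly by the FC2 grouped GEMM without a separate activation or quantization kernel.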

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py

Lines changed: 2 additions & 3 deletions
@@ -52,6 +52,8 @@
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass.cute.nvgpu import cpasync, tcgen05
 
+from .utils import is_power_of_2
+
 
 class Sm100BlockScaledContiguousGroupedGemmKernel:
     """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
@@ -2052,9 +2054,6 @@ def is_valid_mma_tiler_and_cluster_shape(
             is_valid = False
 
         # Skip invalid cluster shape
-        def is_power_of_2(x: int) -> bool:
-            return x > 0 and (x & (x - 1)) == 0
-
         if (
             cluster_shape_mn[0] * cluster_shape_mn[1] > 16
             or cluster_shape_mn[0] <= 0
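
This hunk (and the identical one in the finalize-fusion kernel below) replaces a nested is_power_of_2 definition with an import from the shared blackwell/utils.py module. Based on the removed inline definition, the shared helper is presumably equivalent to:

```python
# Presumed content of the shared helper in blackwell/utils.py, inferred from
# the inline definition removed above; the actual module may contain more.
def is_power_of_2(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0
```

Hoisting it avoids re-creating the nested function on every call to is_valid_mma_tiler_and_cluster_shape and removes the duplicated definition across the grouped-GEMM kernel variants.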

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Lines changed: 2 additions & 3 deletions
@@ -39,6 +39,8 @@
 from cutlass.cute.nvgpu import cpasync, tcgen05
 from cutlass.cutlass_dsl import T, dsl_user_op
 
+from .utils import is_power_of_2
+
 """
 High-performance persistent blockscaled contiguous grouped dense GEMM (C = alpha * (SFA * A) * (SFB * B)) example for
 the NVIDIA Blackwell architecture using CUTE DSL.
@@ -2060,9 +2062,6 @@ def is_valid_mma_tiler_and_cluster_shape(
             is_valid = False
 
         # Skip invalid cluster shape
-        def is_power_of_2(x: int) -> bool:
-            return x > 0 and (x & (x - 1)) == 0
-
         if (
             cluster_shape_mn[0] * cluster_shape_mn[1] > 16
             or cluster_shape_mn[0] <= 0

0 commit comments