
Commit 52fb85e

update finalize fusion

Signed-off-by: Enwei Zhu <[email protected]>

1 parent 6ef1dda

4 files changed: +42 −51 lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 11 additions & 6 deletions
@@ -339,6 +339,9 @@ def __init__(self, num_experts: int, top_k: int, num_local_experts: int,
         self.num_local_experts = num_local_experts
         self.local_expert_offset = local_expert_offset
         self.tile_size = tile_size
+        # Padding values should never be accessed.
+        # Intentionally use a large padding value to expose issues early.
+        self.pad_val = int(2e9)
 
     def get_max_num_tiles(self, num_tokens: int) -> int:
         num_expanded_tokens = num_tokens * self.top_k
@@ -431,8 +434,7 @@ def generate_permuted_idx_to_expanded_idx(
                 permuted_idx_to_expanded_idx.append(expanded_idx)
                 colmajor_expanded_idx += 1
             else:
-                # TODO: Remove this WAR.
-                permuted_idx_to_expanded_idx.append(-1)
+                permuted_idx_to_expanded_idx.append(self.pad_val)
         return permuted_idx_to_expanded_idx
 
     def inputs_pre_hook(self,
@@ -450,7 +452,8 @@ def inputs_pre_hook(self,
         assert num_padding_tiles_val >= 0
 
         tile_idx_to_group_idx = torch.tensor(
-            tile_idx_to_group_idx_list + [int(1e9)] * num_padding_tiles_val,
+            tile_idx_to_group_idx_list +
+            [self.pad_val] * num_padding_tiles_val,
             dtype=tile_idx_to_group_idx.dtype,
             device=tile_idx_to_group_idx.device)
         num_non_exiting_tiles = torch.tensor(
@@ -481,15 +484,17 @@ def inputs_pre_hook_finalize_fusion(
         ) == num_non_exiting_tiles_val * self.tile_size
 
         tile_idx_to_group_idx = torch.tensor(
-            tile_idx_to_group_idx_list + [int(1e9)] * num_padding_tiles_val,
+            tile_idx_to_group_idx_list +
+            [self.pad_val] * num_padding_tiles_val,
             dtype=tile_idx_to_group_idx.dtype,
             device=tile_idx_to_group_idx.device)
         tile_idx_to_mn_limit = torch.tensor(
-            tile_idx_to_mn_limit_list + [int(1e9)] * num_padding_tiles_val,
+            tile_idx_to_mn_limit_list +
+            [self.pad_val] * num_padding_tiles_val,
             dtype=tile_idx_to_mn_limit.dtype,
             device=tile_idx_to_mn_limit.device)
         permuted_idx_to_expanded_idx = torch.tensor(
-            permuted_idx_to_expanded_idx_list + [int(1e9)] *
+            permuted_idx_to_expanded_idx_list + [self.pad_val] *
             (num_padding_tiles_val * self.tile_size),
             dtype=permuted_idx_to_expanded_idx.dtype,
             device=permuted_idx_to_expanded_idx.device)
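The pre-hook metadata above now shares a single sentinel for every padded slot instead of mixing -1 and int(1e9). A minimal sketch of the idea, using assumed example values rather than the op's real inputs: the sentinel still fits in int32 but lies far outside any valid index, so a kernel that accidentally reads a padding slot fails loudly instead of picking up plausible-looking data.

    import torch

    pad_val = int(2e9)                       # fits in int32, far beyond any valid index
    tile_idx_to_group_idx_list = [0, 0, 1]   # hypothetical metadata for 3 valid tiles
    num_padding_tiles_val = 2                # tiles appended to fill the launch grid

    tile_idx_to_group_idx = torch.tensor(
        tile_idx_to_group_idx_list + [pad_val] * num_padding_tiles_val,
        dtype=torch.int32)
    print(tile_idx_to_group_idx)
    # tensor([0, 0, 1, 2000000000, 2000000000], dtype=torch.int32)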

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Lines changed: 31 additions & 31 deletions
@@ -174,7 +174,7 @@ def vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset, loc=None,
             rOut_epi_packed[0].ir_value(),
             rOut_epi_packed[1].ir_value(),
         ],
-        "red.global.v2.f32.add [$0], {$1, $1};",
+        "red.global.v2.f32.add [$0], {$1, $2};",
         "l,f,f",
         has_side_effects=True,
     )
@@ -190,7 +190,7 @@ def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
             rOut_epi_packed.ir_value(),
         ],
         "red.global.add.f32 [$0], $1;",
-        "=l,f",
+        "l,f",
         has_side_effects=True,
         loc=loc,
         ip=ip,
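Both inline-PTX corrections above are operand fixes: the fp32x2 reduction previously passed {$1, $1}, so the second packed element re-added the first register, and the scalar path marked its address with an output constraint ("=l") even though the reduction produces no result. A minimal semantic sketch of the fp32x2 behaviour, as an assumed standalone torch illustration rather than the kernel's actual code path:

    import torch

    out = torch.zeros(2)
    v = torch.tensor([1.0, 2.0])

    fixed = out.clone()
    fixed += v            # {$1, $2}: adds v[0] to out[0] and v[1] to out[1]

    buggy = out.clone()
    buggy += v[0]         # {$1, $1}: both elements receive v[0]

    print(fixed, buggy)   # tensor([1., 2.]) tensor([1., 1.])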
@@ -498,6 +498,7 @@ def __call__(
         gemm_output_major: cutlass.Constexpr,
         tile_idx_to_expert_idx: cute.Tensor,
         num_non_exiting_tiles: cute.Tensor,
+        tile_idx_to_mn_limit: cute.Tensor,
         alpha: cute.Tensor,
         max_active_clusters: cutlass.Constexpr,
         stream: cuda.CUstream,
@@ -739,6 +740,7 @@ class SharedStorage:
             out,
             tile_idx_to_expert_idx,
             num_non_exiting_tiles,
+            tile_idx_to_mn_limit,
             alpha,
             permuted_idx_to_expanded_idx,
             token_final_scales,
@@ -821,6 +823,7 @@ def kernel(
         out: cute.Tensor,
         tile_idx_to_expert_idx: cute.Tensor,
         num_non_exiting_tiles: cute.Tensor,
+        tile_idx_to_mn_limit: cute.Tensor,
         alpha: cute.Tensor,
         permuted_idx_to_expanded_idx: cute.Tensor,
         token_final_scales: cute.Tensor,
@@ -1612,7 +1615,9 @@ def kernel(
                 token_scale = self.final_scale_dtype(0.0)
                 topK = token_final_scales.shape[1]
 
-                if expanded_idx >= 0:
+                tile_mn_limit = tile_idx_to_mn_limit[mma_tile_coord_mnl[0]]
+
+                if permuted_row < tile_mn_limit:
                     token_idx = expanded_idx // topK
                     topk_idx = expanded_idx % topK
                     token_scale = token_final_scales[(token_idx, topk_idx)]
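The epilogue above now derives validity from tile_idx_to_mn_limit instead of from a negative expanded index. A minimal host-side sketch with assumed values, mirroring the Python-side workaround that the later files delete: under the assumed semantics, tile_idx_to_mn_limit[t] is the first padded permuted-row index of tile t, so a row is real exactly when permuted_row < tile_idx_to_mn_limit[t].

    import torch

    tile_size = 4
    tile_idx_to_mn_limit = torch.tensor([3, 8, 10])   # assumed limits for 3 tiles

    rows = torch.arange(3 * tile_size)                # global permuted row indices
    is_padding = rows >= tile_idx_to_mn_limit.repeat_interleave(tile_size)
    print(is_padding.view(3, tile_size))
    # tensor([[False, False, False,  True],     tile 0: rows 0-2 valid
    #         [False, False, False, False],     tile 1: all rows valid
    #         [False, False,  True,  True]])    tile 2: rows 8-9 valid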
@@ -1652,24 +1657,25 @@ def kernel(
 
                 rOut_epi.store(acc_vec_finalized.to(self.out_dtype))
 
-                coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[
-                    1
-                ] + subtile_idx * cute.size(tTR_rAcc)
-
-                for index in cutlass.range(loop_size):
-                    scatter_out_offset = cute.domain_offset((0, coord_n, 0), scatter_out)
-                    if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
-                        rOut_epi_packed = rOut_epi[index, None, None]
-                        vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset)
-                        coord_n += cute.size(rOut_epi_packed)
-                    elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
-                        rOut_epi_packed = rOut_epi[index, None]
-                        vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset)
-                        coord_n += cute.size(rOut_epi_packed)
-                    else:
-                        rOut_epi_packed = rOut_epi[index]
-                        atomic_add_func(rOut_epi_packed, scatter_out_offset)
-                        coord_n += 1
+                if permuted_row < tile_mn_limit:
+                    coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[
+                        1
+                    ] + subtile_idx * cute.size(tTR_rAcc)
+
+                    for index in cutlass.range(loop_size):
+                        scatter_out_offset = cute.domain_offset((0, coord_n, 0), scatter_out)
+                        if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
+                            rOut_epi_packed = rOut_epi[index, None, None]
+                            vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset)
+                            coord_n += cute.size(rOut_epi_packed)
+                        elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
+                            rOut_epi_packed = rOut_epi[index, None]
+                            vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset)
+                            coord_n += cute.size(rOut_epi_packed)
+                        else:
+                            rOut_epi_packed = rOut_epi[index]
+                            atomic_add_func(rOut_epi_packed, scatter_out_offset)
+                            coord_n += 1
                 self.epilog_sync_barrier.arrive_and_wait()
                 #
                 # Async arrive accumulator buffer empty
@@ -1697,10 +1703,6 @@ def kernel(
             tmem.relinquish_alloc_permit()
             self.epilog_sync_barrier.arrive_and_wait()
             tmem.free(tmem_ptr)
-            #
-            # Wait for C store complete
-            #
-            # c_pipeline.producer_tail()
 
     def epilog_tmem_copy_and_partition(
         self,
@@ -1858,9 +1860,6 @@ def _compute_stages(
         # Start with total smem per CTA (capacity / occupancy)
         # Subtract reserved bytes and initial C stages bytes
         # Divide remaining by bytes needed per A/B stage
-        # cute.printf("num_smem_capacity: {}, occupancy: {}, mbar_helpers_bytes: {}, c_bytes: {}", num_smem_capacity,
-        #             occupancy, mbar_helpers_bytes, c_bytes)
-        # cute.printf("ab_bytes_per_stage: {}", ab_bytes_per_stage)
         num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes)) // ab_bytes_per_stage
 
         # Refine epilogue stages:
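With the debug prints gone, the stage computation is simply the recipe in the comments above: take the per-CTA shared-memory budget, subtract the mbarrier-helper bytes, and divide by the bytes one A/B stage needs. A small arithmetic sketch with assumed numbers (not Blackwell's real limits):

    num_smem_capacity = 232448    # assumed SMEM bytes available per SM
    occupancy = 1                 # CTAs resident per SM
    mbar_helpers_bytes = 1024     # assumed bytes reserved for mbarrier helpers
    ab_bytes_per_stage = 32768    # assumed bytes per A/B pipeline stage

    num_ab_stage = (num_smem_capacity // occupancy - mbar_helpers_bytes) // ab_bytes_per_stage
    print(num_ab_stage)           # 7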
@@ -2282,9 +2281,9 @@ def wrapper(
         tile_idx_to_group_idx = cute.make_tensor(
             tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,))
         )
-        # tile_idx_to_mn_limit = cute.make_tensor(
-        #     tile_idx_to_mn_limit_ptr, layout=cute.make_layout((num_tiles,))
-        # )
+        tile_idx_to_mn_limit = cute.make_tensor(
+            tile_idx_to_mn_limit_ptr, layout=cute.make_layout((num_tiles,))
+        )
         permuted_idx_to_expanded_idx = cute.make_tensor(
             permuted_idx_to_expanded_idx_ptr, layout=cute.make_layout((m,))
         )
@@ -2305,6 +2304,7 @@ def wrapper(
             "n",
             tile_idx_to_group_idx,
             num_non_exiting_tiles,
+            tile_idx_to_mn_limit,
             alpha,
             max_active_clusters=max_active_clusters,
             stream=stream,

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 0 additions & 7 deletions
@@ -412,13 +412,6 @@ def forward_chunk_nvfp4(
             tile_tokens_dim=tile_size,
         )
 
-        # TODO: Remove this WAR.
-        max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0)
-        permuted_idx_to_expanded_idx.masked_fill_(
-            torch.arange(max_num_permuted_tokens, device='cuda')
-            >= tile_idx_to_mn_limit.repeat_interleave(tile_size), -1)
-        x_sf.masked_fill_(x_sf.view(torch.float8_e4m3fn).isnan(), 0)
-
         x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
             input=x.view(torch.float4_e2m1fn_x2),
             weight=self.w2_weight.view(torch.float4_e2m1fn_x2),

tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py

Lines changed: 0 additions & 7 deletions
@@ -467,14 +467,7 @@ def test_nvfp4_grouped_gemm_finalize_blackwell(
         tile_tokens_dim=tile_size,
     )
 
-    # TODO: Remove this WAR.
     max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0)
-    permuted_idx_to_expanded_idx.masked_fill_(
-        torch.arange(max_num_permuted_tokens, device="cuda")
-        >= tile_idx_to_mn_limit.repeat_interleave(tile_size),
-        -1,
-    )
-
     a = torch.randint(
         -100, 100, (max_num_permuted_tokens, hidden_size // 2), dtype=torch.int32, device="cuda"
     )
