@@ -174,7 +174,7 @@ def vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset, loc=None,
             rOut_epi_packed[0].ir_value(),
             rOut_epi_packed[1].ir_value(),
         ],
-        "red.global.v2.f32.add [$0], {$1, $1};",
+        "red.global.v2.f32.add [$0], {$1, $2};",
        "l,f,f",
        has_side_effects=True,
    )
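
Note on the hunk above: the constraint string "l,f,f" binds $0 to the 64-bit destination address and $1/$2 to the two packed floats, but the old PTX template wrote {$1, $1}, so both lanes of the vectorized reduction accumulated the first value and the second was dropped. A minimal pure-Python illustration of the semantic difference (numpy stands in for the device-side reduction; this is not the DSL code itself):

    import numpy as np

    out = np.zeros(2, dtype=np.float32)
    packed = np.array([1.5, 2.5], dtype=np.float32)

    # old template "{$1, $1}": both lanes accumulate packed[0]
    buggy = out + np.array([packed[0], packed[0]])   # -> [1.5, 1.5]
    # new template "{$1, $2}": each lane accumulates its own value
    fixed = out + np.array([packed[0], packed[1]])   # -> [1.5, 2.5]
    print(buggy, fixed)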
@@ -190,7 +190,7 @@ def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
            rOut_epi_packed.ir_value(),
        ],
        "red.global.add.f32 [$0], $1;",
-       "=l,f",
+       "l,f",
        has_side_effects=True,
        loc=loc,
        ip=ip,
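
In LLVM inline assembly, a leading "=" marks an output operand, and the $n placeholders number outputs first, then inputs. red.global.add.f32 is a fire-and-forget reduction with no result, so "=l,f" mis-bound the operands ($0 became a phantom output register instead of the address); "l,f" correctly binds $0 to the address and $1 to the value. A tiny hypothetical helper showing the mapping:

    # Hypothetical helper: LLVM inline-asm placeholders number outputs
    # (prefixed "=") before inputs.
    def placeholder_map(constraints: str) -> dict:
        parts = [c.strip() for c in constraints.split(",")]
        outs = [c for c in parts if c.startswith("=")]
        ins = [c for c in parts if not c.startswith("=")]
        return {f"${i}": c for i, c in enumerate(outs + ins)}

    print(placeholder_map("=l,f"))  # {'$0': '=l', '$1': 'f'} -- $0 is a bogus output
    print(placeholder_map("l,f"))   # {'$0': 'l', '$1': 'f'}  -- address, then value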
@@ -498,6 +498,7 @@ def __call__(
        gemm_output_major: cutlass.Constexpr,
        tile_idx_to_expert_idx: cute.Tensor,
        num_non_exiting_tiles: cute.Tensor,
+       tile_idx_to_mn_limit: cute.Tensor,
        alpha: cute.Tensor,
        max_active_clusters: cutlass.Constexpr,
        stream: cuda.CUstream,
@@ -739,6 +740,7 @@ class SharedStorage:
            out,
            tile_idx_to_expert_idx,
            num_non_exiting_tiles,
+           tile_idx_to_mn_limit,
            alpha,
            permuted_idx_to_expanded_idx,
            token_final_scales,
@@ -821,6 +823,7 @@ def kernel(
        out: cute.Tensor,
        tile_idx_to_expert_idx: cute.Tensor,
        num_non_exiting_tiles: cute.Tensor,
+       tile_idx_to_mn_limit: cute.Tensor,
        alpha: cute.Tensor,
        permuted_idx_to_expanded_idx: cute.Tensor,
        token_final_scales: cute.Tensor,
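
The three hunks above thread the new tile_idx_to_mn_limit tensor from the public __call__ signature through the host-side launch into the device kernel. It holds one entry per scheduled tile: the exclusive upper bound on valid (non-padding) rows for that tile. A hedged host-side sketch of how such a table could be built for a grouped/MoE GEMM (the 128-row CTA tile and the per-expert token counts are made-up example inputs, not taken from this diff):

    import numpy as np

    cta_tile_m = 128                      # assumed CTA tile height
    expert_token_counts = [300, 17, 128]  # example per-expert token counts

    limits, row_start = [], 0
    for count in expert_token_counts:
        padded = -(-count // cta_tile_m) * cta_tile_m  # round up to the tile grid
        # every tile of this expert shares the expert's valid-row upper bound
        limits += [row_start + count] * (padded // cta_tile_m)
        row_start += padded
    tile_idx_to_mn_limit = np.asarray(limits, dtype=np.int32)
    print(tile_idx_to_mn_limit)  # [300 300 300 401 640]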
@@ -1612,7 +1615,9 @@ def kernel(
        token_scale = self.final_scale_dtype(0.0)
        topK = token_final_scales.shape[1]

-       if expanded_idx >= 0:
+       tile_mn_limit = tile_idx_to_mn_limit[mma_tile_coord_mnl[0]]
+
+       if permuted_row < tile_mn_limit:
            token_idx = expanded_idx // topK
            topk_idx = expanded_idx % topK
            token_scale = token_final_scales[(token_idx, topk_idx)]
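
The guard changes from the per-row sentinel test expanded_idx >= 0 to a per-tile bound: the kernel reads the tile's limit once (indexed by the tile's M coordinate, mma_tile_coord_mnl[0]) and compares the row index against it, so padding rows keep token_scale at 0.0. An illustrative Python sketch of the masking (all values made up):

    topK = 4
    tile_mn_limit = 300  # looked up from tile_idx_to_mn_limit for this tile

    for permuted_row, expanded_idx in [(298, 1193), (299, 41), (300, -1)]:
        if permuted_row < tile_mn_limit:
            token_idx, topk_idx = divmod(expanded_idx, topK)  # row is live
            print(permuted_row, "->", (token_idx, topk_idx))
        else:
            print(permuted_row, "-> masked padding row, scale stays 0.0")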
@@ -1652,24 +1657,25 @@ def kernel(

        rOut_epi.store(acc_vec_finalized.to(self.out_dtype))

-       coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[
-           1
-       ] + subtile_idx * cute.size(tTR_rAcc)
-
-       for index in cutlass.range(loop_size):
-           scatter_out_offset = cute.domain_offset((0, coord_n, 0), scatter_out)
-           if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
-               rOut_epi_packed = rOut_epi[index, None, None]
-               vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset)
-               coord_n += cute.size(rOut_epi_packed)
-           elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
-               rOut_epi_packed = rOut_epi[index, None]
-               vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset)
-               coord_n += cute.size(rOut_epi_packed)
-           else:
-               rOut_epi_packed = rOut_epi[index]
-               atomic_add_func(rOut_epi_packed, scatter_out_offset)
-               coord_n += 1
+       if permuted_row < tile_mn_limit:
+           coord_n = mma_tile_coord_mnl[1] * self.cta_tile_shape_mnk[
+               1
+           ] + subtile_idx * cute.size(tTR_rAcc)
+
+           for index in cutlass.range(loop_size):
+               scatter_out_offset = cute.domain_offset((0, coord_n, 0), scatter_out)
+               if cutlass.const_expr(self.out_dtype == cutlass.BFloat16):
+                   rOut_epi_packed = rOut_epi[index, None, None]
+                   vectorized_atomic_add_bf16x8(rOut_epi_packed, scatter_out_offset)
+                   coord_n += cute.size(rOut_epi_packed)
+               elif cutlass.const_expr(self.out_dtype == cutlass.Float32):
+                   rOut_epi_packed = rOut_epi[index, None]
+                   vectorized_atomic_add_fp32x2(rOut_epi_packed, scatter_out_offset)
+                   coord_n += cute.size(rOut_epi_packed)
+               else:
+                   rOut_epi_packed = rOut_epi[index]
+                   atomic_add_func(rOut_epi_packed, scatter_out_offset)
+                   coord_n += 1
        self.epilog_sync_barrier.arrive_and_wait()
        #
        # Async arrive accumulator buffer empty
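
The hunk above places the entire scatter atomic-add loop under the same permuted_row < tile_mn_limit guard, so padding rows issue no red.global traffic at all; note that self.epilog_sync_barrier.arrive_and_wait() deliberately stays outside the guard so every thread still arrives at the barrier. A plain-Python sketch of the vector-width dispatch (the per-call element counts mirror the three branches above; loop_size and the dtype strings are assumptions for illustration):

    def red_width(out_dtype: str) -> int:
        # elements consumed per reduction call, as in the branches above
        if out_dtype == "bf16":
            return 8   # vectorized_atomic_add_bf16x8
        if out_dtype == "fp32":
            return 2   # vectorized_atomic_add_fp32x2
        return 1       # scalar atomic_add_func fallback

    coord_n = 0
    for _ in range(4):            # loop_size = 4, made up
        coord_n += red_width("fp32")
    print(coord_n)                # 8 -- column offset advances by 2 per call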
@@ -1697,10 +1703,6 @@ def kernel(
        tmem.relinquish_alloc_permit()
        self.epilog_sync_barrier.arrive_and_wait()
        tmem.free(tmem_ptr)
-       #
-       # Wait for C store complete
-       #
-       # c_pipeline.producer_tail()

    def epilog_tmem_copy_and_partition(
        self,
@@ -1858,9 +1860,6 @@ def _compute_stages(
        # Start with total smem per CTA (capacity / occupancy)
        # Subtract reserved bytes and initial C stages bytes
        # Divide remaining by bytes needed per A/B stage
-       # cute.printf("num_smem_capacity: {}, occupancy: {}, mbar_helpers_bytes: {}, c_bytes: {}", num_smem_capacity,
-       #     occupancy, mbar_helpers_bytes, c_bytes)
-       # cute.printf("ab_bytes_per_stage: {}", ab_bytes_per_stage)
        num_ab_stage = (num_smem_capacity // occupancy - (mbar_helpers_bytes)) // ab_bytes_per_stage

        # Refine epilogue stages:
@@ -2282,9 +2281,9 @@ def wrapper(
        tile_idx_to_group_idx = cute.make_tensor(
            tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,))
        )
-       # tile_idx_to_mn_limit = cute.make_tensor(
-       #     tile_idx_to_mn_limit_ptr, layout=cute.make_layout((num_tiles,))
-       # )
+       tile_idx_to_mn_limit = cute.make_tensor(
+           tile_idx_to_mn_limit_ptr, layout=cute.make_layout((num_tiles,))
+       )
        permuted_idx_to_expanded_idx = cute.make_tensor(
            permuted_idx_to_expanded_idx_ptr, layout=cute.make_layout((m,))
        )
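
This hunk enables the previously commented-out wrapper: tile_idx_to_mn_limit gets the same flat (num_tiles,) layout as tile_idx_to_group_idx. A hedged sketch of how the backing pointer might be produced upstream, assuming a torch int32 device buffer and the CuTe DSL's make_ptr helper (the helper and its argument names are an assumption, not confirmed by this diff):

    import torch
    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import make_ptr  # assumed import path

    num_tiles = 64
    mn_limit_buf = torch.zeros(num_tiles, dtype=torch.int32, device="cuda")
    # wrap the raw device address so cute.make_tensor can attach a layout
    tile_idx_to_mn_limit_ptr = make_ptr(
        cutlass.Int32, mn_limit_buf.data_ptr(), cute.AddressSpace.gmem
    )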
@@ -2305,6 +2304,7 @@ def wrapper(
            "n",
            tile_idx_to_group_idx,
            num_non_exiting_tiles,
+           tile_idx_to_mn_limit,
            alpha,
            max_active_clusters=max_active_clusters,
            stream=stream,