Skip to content

Commit afa8452

Browse files
tc-oaiwdziurdz
authored and committed
Remove redundant reduce for topk=1 (#8647)
1 parent f3afc0b commit afa8452

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -253,7 +253,7 @@ def init_allocation(x, w, precision_config, fused_activation,
253253
# ---- scratchpad -----#
254254
scratchpad = dict()
255255
N_scratch = N // fused_activation.specs.reduction_n if opt_flags.split_k == 1 else N
256-
if opt_flags.split_k > 1 or scatter_indx is not None:
256+
if opt_flags.split_k > 1 or (scatter_indx is not None and (not is_cuda() or routing_data.n_expts_act > 1)):
257257
scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
258258
scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N_scratch), scratch_out_dtype)
259259
if "matmul" in scratchpad and precision_config.out_scale is not None:
@@ -654,7 +654,7 @@ def matmul_ogs(x, w, bias,
654654
if y_mx_scale is not None:
655655
out_final_mx_scale = y_mx_scale.view(out_matmul.shape[-2], triton.cdiv(out_matmul.shape[-1], 32))
656656
# TODO: change `matmul_ogs` semantics and move this to another op!
657-
if scatter_indx is not None:
657+
if scatter_indx is not None and (not is_cuda() or routing_data.n_expts_act > 1): # Matmul ogs kernel fuses scatter already, so only need for n_exps_act > 1.
658658
mask = (scatter_indx.src_indx != -1).view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, 1)
659659
out_matmul = out_matmul.view(out_matmul.shape[-2]//routing_data.n_expts_act, routing_data.n_expts_act, -1)
660660
mask = mask.expand_as(out_matmul)

0 commit comments

Comments
 (0)