Skip to content

Commit cf4f254

Browse files
committed
Use helion.cdiv
stack-info: PR: #852, branch: oulgen/stack/129
1 parent d847410 commit cf4f254

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

examples/grouped_gemm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ def grouped_gemm_jagged_persistent(
165165

166166
if m_size > 0:
167167
# Compute tile grid dimensions for current group
168-
num_m_tiles = (m_size + BLOCK_M - 1) // BLOCK_M
168+
num_m_tiles = helion.cdiv(m_size, BLOCK_M) # pyright: ignore[reportArgumentType]
169169
# Calculate number of N tiles (shared across all groups)
170-
num_n_tiles = (N + BLOCK_N - 1) // BLOCK_N
170+
num_n_tiles = helion.cdiv(N, BLOCK_N) # pyright: ignore[reportArgumentType]
171171
num_group_tiles = num_m_tiles * num_n_tiles
172172

173173
# Distribute tiles among workers using strided access pattern

examples/rms_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def rms_norm_bwd(
9696
m_block = hl.register_block_size(x.size(0))
9797
grad_x = torch.empty_like(x)
9898
grad_weight = x.new_empty(
99-
[(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32
99+
[helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32
100100
)
101101
weight_shape = hl.specialize(weight.size(0))
102102
for mb_cta in hl.tile(x.size(0), block_size=m_block):

test/test_examples.expected

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3153,6 +3153,7 @@ def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.
31533153
from __future__ import annotations
31543154

31553155
import torch
3156+
import helion
31563157
import triton
31573158
import triton.language as tl
31583159
from helion.runtime import default_launcher as _default_launcher
@@ -3218,7 +3219,7 @@ def rms_norm_bwd(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
32183219
"""
32193220
m_block = 32
32203221
grad_x = torch.empty_like(x)
3221-
grad_weight = x.new_empty([(x.size(0) + m_block - 1) // m_block, *weight.shape], dtype=torch.float32)
3222+
grad_weight = x.new_empty([helion.cdiv(x.size(0), m_block), *weight.shape], dtype=torch.float32)
32223223
_BLOCK_SIZE_0 = 32
32233224
_RDIM_SIZE_2 = 64
32243225
_launcher(_helion_rms_norm_bwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, grad_out, rsqrt, weight, grad_x, grad_weight, x.size(0), grad_out.stride(0), grad_out.stride(1), grad_weight.stride(0), grad_weight.stride(1), grad_x.stride(0), grad_x.stride(1), rsqrt.stride(0), weight.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_2, num_warps=4, num_stages=3)

0 commit comments

Comments (0)