Skip to content

Commit 0361edf

Decrease num_stages default from 3 to 2, to avoid shared memory OOM (#841)

1 parent 70e9182, commit 0361edf

33 files changed (+424 / -424 lines changed)

helion/autotuner/config_spec.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
 from ..runtime.config import PidTypeLiteral
 
 DEFAULT_NUM_WARPS = 4
-DEFAULT_NUM_STAGES = 3
+DEFAULT_NUM_STAGES = 2
 VALID_KEYS: frozenset[str] = frozenset(
     [
         "block_sizes",

test/test_associative_scan.expected

Lines changed: 42 additions & 42 deletions
Large diffs are not rendered by default.

test/test_atomic_ops.expected

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default
2727
"""Test atomic_add with 2D indexing."""
2828
_BLOCK_SIZE_0 = 8
2929
_BLOCK_SIZE_1 = 8
30-
_launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
30+
_launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
3131
return x
3232

3333
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_1d_tensor)
@@ -59,7 +59,7 @@ def atomic_add_1d_tensor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_
5959
z = torch.zeros([n], dtype=x.dtype, device=x.device)
6060
_BLOCK_SIZE_0 = 32
6161
_RDIM_SIZE_1 = 64
62-
_launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, z, x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
62+
_launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, z, x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2)
6363
return z
6464

6565
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_float)
@@ -82,7 +82,7 @@ def _helion_atomic_add_float_kernel(indices, x, indices_size_0, indices_stride_0
8282
def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher):
8383
"""Test atomic_add with a float constant value and reading from lookup"""
8484
_BLOCK_SIZE_0 = 32
85-
_launcher(_helion_atomic_add_float_kernel, (triton.cdiv(indices.size(0), _BLOCK_SIZE_0),), indices, x, indices.size(0), indices.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
85+
_launcher(_helion_atomic_add_float_kernel, (triton.cdiv(indices.size(0), _BLOCK_SIZE_0),), indices, x, indices.size(0), indices.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
8686
return x
8787

8888
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_returns_prev)
@@ -106,7 +106,7 @@ def _helion_k(x, y, prev, x_size_0, prev_stride_0, x_stride_0, y_stride_0, _BLOC
106106
def k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
107107
prev = torch.empty_like(x)
108108
_BLOCK_SIZE_0 = 8
109-
_launcher(_helion_k, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, prev, x.size(0), prev.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
109+
_launcher(_helion_k, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, prev, x.size(0), prev.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
110110
return (x, prev)
111111

112112
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_w_tile_attr)
@@ -127,7 +127,7 @@ def atomic_add_w_tile_attr(x: torch.Tensor, *, _launcher=_default_launcher):
127127
"""Test atomic_add where the index is a symbolic int"""
128128
y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
129129
_BLOCK_SIZE_0 = 2
130-
_launcher(_helion_atomic_add_w_tile_attr, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), y, y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
130+
_launcher(_helion_atomic_add_w_tile_attr, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), y, y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
131131
return y
132132

133133
--- assertExpectedJournal(TestAtomicOperations.test_atomic_and)
@@ -149,7 +149,7 @@ def _helion_atomic_and_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
149149

150150
def atomic_and_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
151151
_BLOCK_SIZE_0 = 8
152-
_launcher(_helion_atomic_and_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
152+
_launcher(_helion_atomic_and_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
153153
return x
154154

155155
--- assertExpectedJournal(TestAtomicOperations.test_atomic_cas)
@@ -172,7 +172,7 @@ def _helion_atomic_cas_kernel(x, expect, y, x_size_0, expect_stride_0, x_stride_
172172

173173
def atomic_cas_kernel(x: torch.Tensor, y: torch.Tensor, expect: torch.Tensor, *, _launcher=_default_launcher):
174174
_BLOCK_SIZE_0 = 4
175-
_launcher(_helion_atomic_cas_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, expect, y, x.size(0), expect.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
175+
_launcher(_helion_atomic_cas_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, expect, y, x.size(0), expect.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
176176
return x
177177

178178
--- assertExpectedJournal(TestAtomicOperations.test_atomic_max)
@@ -194,7 +194,7 @@ def _helion_atomic_max_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
194194

195195
def atomic_max_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
196196
_BLOCK_SIZE_0 = 4
197-
_launcher(_helion_atomic_max_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
197+
_launcher(_helion_atomic_max_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
198198
return x
199199

200200
--- assertExpectedJournal(TestAtomicOperations.test_atomic_min)
@@ -216,7 +216,7 @@ def _helion_atomic_min_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
216216

217217
def atomic_min_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
218218
_BLOCK_SIZE_0 = 4
219-
_launcher(_helion_atomic_min_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
219+
_launcher(_helion_atomic_min_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
220220
return x
221221

222222
--- assertExpectedJournal(TestAtomicOperations.test_atomic_or)
@@ -238,7 +238,7 @@ def _helion_atomic_or_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE
238238

239239
def atomic_or_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
240240
_BLOCK_SIZE_0 = 8
241-
_launcher(_helion_atomic_or_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
241+
_launcher(_helion_atomic_or_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
242242
return x
243243

244244
--- assertExpectedJournal(TestAtomicOperations.test_atomic_xchg)
@@ -260,7 +260,7 @@ def _helion_atomic_xchg_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SI
260260

261261
def atomic_xchg_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
262262
_BLOCK_SIZE_0 = 8
263-
_launcher(_helion_atomic_xchg_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
263+
_launcher(_helion_atomic_xchg_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
264264
return x
265265

266266
--- assertExpectedJournal(TestAtomicOperations.test_atomic_xor)
@@ -282,7 +282,7 @@ def _helion_atomic_xor_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
282282

283283
def atomic_xor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
284284
_BLOCK_SIZE_0 = 8
285-
_launcher(_helion_atomic_xor_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
285+
_launcher(_helion_atomic_xor_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
286286
return x
287287

288288
--- assertExpectedJournal(TestAtomicOperations.test_basic_atomic_add)
@@ -305,7 +305,7 @@ def _helion_atomic_add_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
305305
def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
306306
"""Test basic atomic_add functionality."""
307307
_BLOCK_SIZE_0 = 32
308-
_launcher(_helion_atomic_add_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
308+
_launcher(_helion_atomic_add_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
309309
return x
310310

311311
--- assertExpectedJournal(TestAtomicOperations.test_overlapping_atomic_add)
@@ -329,5 +329,5 @@ def _helion_atomic_add_overlap_kernel(indices, y, x, _BLOCK_SIZE_0: tl.constexpr
329329
def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher):
330330
"""Test atomic_add with overlapping indices."""
331331
_BLOCK_SIZE_0 = 32
332-
_launcher(_helion_atomic_add_overlap_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
332+
_launcher(_helion_atomic_add_overlap_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
333333
return x

test/test_autotuner.expected

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_auto
22
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
33

44
--- assertExpectedJournal(TestAutotuner.test_config_fragment0)
5-
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
5+
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
66
helion.Config(block_sizes=[32, 128, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 2], range_warp_specializes=[None, True])
77
helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 3], range_warp_specializes=[None, False])
88
helion.Config(block_sizes=[16, 32, 256], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 3], range_warp_specializes=[True, None])
@@ -14,7 +14,7 @@ helion.Config(block_sizes=[256, 16, 16], indexing='pointer', l2_groupings=[2], l
1414
helion.Config(block_sizes=[16, 64, 16], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=32, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[False, None], range_num_stages=[3, 0], range_unroll_factors=[3, 4], range_warp_specializes=[False, True])
1515

1616
--- assertExpectedJournal(TestAutotuner.test_config_fragment1)
17-
helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=3, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
17+
helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
1818
helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True])
1919
helion.Config(block_sizes=[2, 8, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
2020
helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[2], range_warp_specializes=[True])

test/test_broadcasting.expected

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
3434
out1 = torch.empty_like(a)
3535
_BLOCK_SIZE_0 = 16
3636
_BLOCK_SIZE_1 = 8
37-
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
37+
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
3838
return (out0, out1)
3939

4040
--- assertExpectedJournal(TestBroadcasting.test_broadcast2)
@@ -70,7 +70,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
7070
out1 = torch.empty_like(a)
7171
_BLOCK_SIZE_1 = 8
7272
_BLOCK_SIZE_0 = 16
73-
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(1), _BLOCK_SIZE_1) * triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
73+
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(1), _BLOCK_SIZE_1) * triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
7474
return (out0, out1)
7575

7676
--- assertExpectedJournal(TestBroadcasting.test_broadcast3)
@@ -104,7 +104,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
104104
out0 = torch.empty_like(a)
105105
out1 = torch.empty_like(a)
106106
_BLOCK_SIZE_0 = 64
107-
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * a.size(1),), a, b, out0, out1, a.size(0), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
107+
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * a.size(1),), a, b, out0, out1, a.size(0), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=2)
108108
return (out0, out1)
109109

110110
--- assertExpectedJournal(TestBroadcasting.test_broadcast4)
@@ -138,7 +138,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
138138
out0 = torch.empty_like(a)
139139
out1 = torch.empty_like(a)
140140
_BLOCK_SIZE_1 = 64
141-
_launcher(_helion_broadcast_fn, (a.size(0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, num_warps=4, num_stages=3)
141+
_launcher(_helion_broadcast_fn, (a.size(0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_1, num_warps=4, num_stages=2)
142142
return (out0, out1)
143143

144144
--- assertExpectedJournal(TestBroadcasting.test_broadcast5)
@@ -170,7 +170,7 @@ def broadcast_fn(a, b, *, _launcher=_default_launcher):
170170
out1 = torch.empty_like(a)
171171
_BLOCK_SIZE_0 = 32
172172
_BLOCK_SIZE_1 = 32
173-
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), b.size(0), out0.size(0), out0.size(1), out1.size(0), out1.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
173+
_launcher(_helion_broadcast_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out0, out1, a.size(0), a.size(1), b.size(0), out0.size(0), out0.size(1), out1.size(0), out1.size(1), a.stride(0), a.stride(1), b.stride(0), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
174174
return (out0, out1)
175175

176176
--- assertExpectedJournal(TestBroadcasting.test_constexpr_index)
@@ -212,7 +212,7 @@ def fn(a, idx1, *, _launcher=_default_launcher):
212212
out2 = torch.empty_like(a)
213213
_BLOCK_SIZE_0 = 16
214214
_BLOCK_SIZE_1 = 16
215-
_launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
215+
_launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, out0, out1, out2, a.size(0), a.size(1), a.stride(0), a.stride(1), out0.stride(0), out0.stride(1), out1.stride(0), out1.stride(1), out2.stride(0), out2.stride(1), idx1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
216216
return (out0, out1, out2)
217217

218218
--- assertExpectedJournal(TestBroadcasting.test_implicit_broadcast)
@@ -244,7 +244,7 @@ def fn(a, b, *, _launcher=_default_launcher):
244244
out = torch.empty_like(a)
245245
_BLOCK_SIZE_0 = 16
246246
_BLOCK_SIZE_1 = 16
247-
_launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
247+
_launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0) * triton.cdiv(a.size(1), _BLOCK_SIZE_1),), a, b, out, a.size(0), a.size(1), a.stride(0), a.stride(1), b.stride(0), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2)
248248
return out
249249

250250
--- assertExpectedJournal(TestBroadcasting.test_python_float_promotion)
@@ -265,5 +265,5 @@ def _helion_fn(a, a_size_0, a_stride_0, beta, _BLOCK_SIZE_0: tl.constexpr):
265265

266266
def fn(a, beta, *, _launcher=_default_launcher):
267267
_BLOCK_SIZE_0 = 16
268-
_launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, a.size(0), a.stride(0), beta, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
268+
_launcher(_helion_fn, (triton.cdiv(a.size(0), _BLOCK_SIZE_0),), a, a.size(0), a.stride(0), beta, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
269269
return a

0 commit comments

Comments (0)