@@ -27,7 +27,7 @@ def atomic_add_2d_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default
27
27
"""Test atomic_add with 2D indexing."""
28
28
_BLOCK_SIZE_0 = 8
29
29
_BLOCK_SIZE_1 = 8
30
- _launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3 )
30
+ _launcher(_helion_atomic_add_2d_kernel, (triton.cdiv(y.size(0), _BLOCK_SIZE_0) * triton.cdiv(y.size(1), _BLOCK_SIZE_1),), y, x, y.size(0), y.size(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=2 )
31
31
return x
32
32
33
33
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_1d_tensor)
@@ -59,7 +59,7 @@ def atomic_add_1d_tensor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_
59
59
z = torch.zeros([n], dtype=x.dtype, device=x.device)
60
60
_BLOCK_SIZE_0 = 32
61
61
_RDIM_SIZE_1 = 64
62
- _launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, z, x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3 )
62
+ _launcher(_helion_atomic_add_1d_tensor_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, y, z, x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=2 )
63
63
return z
64
64
65
65
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_float)
@@ -82,7 +82,7 @@ def _helion_atomic_add_float_kernel(indices, x, indices_size_0, indices_stride_0
82
82
def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher):
83
83
"""Test atomic_add with a float constant value and reading from lookup"""
84
84
_BLOCK_SIZE_0 = 32
85
- _launcher(_helion_atomic_add_float_kernel, (triton.cdiv(indices.size(0), _BLOCK_SIZE_0),), indices, x, indices.size(0), indices.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
85
+ _launcher(_helion_atomic_add_float_kernel, (triton.cdiv(indices.size(0), _BLOCK_SIZE_0),), indices, x, indices.size(0), indices.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
86
86
return x
87
87
88
88
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_returns_prev)
@@ -106,7 +106,7 @@ def _helion_k(x, y, prev, x_size_0, prev_stride_0, x_stride_0, y_stride_0, _BLOC
106
106
def k(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
107
107
prev = torch.empty_like(x)
108
108
_BLOCK_SIZE_0 = 8
109
- _launcher(_helion_k, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, prev, x.size(0), prev.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
109
+ _launcher(_helion_k, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, prev, x.size(0), prev.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
110
110
return (x, prev)
111
111
112
112
--- assertExpectedJournal(TestAtomicOperations.test_atomic_add_w_tile_attr)
@@ -127,7 +127,7 @@ def atomic_add_w_tile_attr(x: torch.Tensor, *, _launcher=_default_launcher):
127
127
"""Test atomic_add where the index is a symbolic int"""
128
128
y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
129
129
_BLOCK_SIZE_0 = 2
130
- _launcher(_helion_atomic_add_w_tile_attr, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), y, y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
130
+ _launcher(_helion_atomic_add_w_tile_attr, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), y, y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
131
131
return y
132
132
133
133
--- assertExpectedJournal(TestAtomicOperations.test_atomic_and)
@@ -149,7 +149,7 @@ def _helion_atomic_and_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
149
149
150
150
def atomic_and_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
151
151
_BLOCK_SIZE_0 = 8
152
- _launcher(_helion_atomic_and_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
152
+ _launcher(_helion_atomic_and_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
153
153
return x
154
154
155
155
--- assertExpectedJournal(TestAtomicOperations.test_atomic_cas)
@@ -172,7 +172,7 @@ def _helion_atomic_cas_kernel(x, expect, y, x_size_0, expect_stride_0, x_stride_
172
172
173
173
def atomic_cas_kernel(x: torch.Tensor, y: torch.Tensor, expect: torch.Tensor, *, _launcher=_default_launcher):
174
174
_BLOCK_SIZE_0 = 4
175
- _launcher(_helion_atomic_cas_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, expect, y, x.size(0), expect.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
175
+ _launcher(_helion_atomic_cas_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, expect, y, x.size(0), expect.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
176
176
return x
177
177
178
178
--- assertExpectedJournal(TestAtomicOperations.test_atomic_max)
@@ -194,7 +194,7 @@ def _helion_atomic_max_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
194
194
195
195
def atomic_max_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
196
196
_BLOCK_SIZE_0 = 4
197
- _launcher(_helion_atomic_max_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
197
+ _launcher(_helion_atomic_max_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
198
198
return x
199
199
200
200
--- assertExpectedJournal(TestAtomicOperations.test_atomic_min)
@@ -216,7 +216,7 @@ def _helion_atomic_min_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
216
216
217
217
def atomic_min_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
218
218
_BLOCK_SIZE_0 = 4
219
- _launcher(_helion_atomic_min_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
219
+ _launcher(_helion_atomic_min_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
220
220
return x
221
221
222
222
--- assertExpectedJournal(TestAtomicOperations.test_atomic_or)
@@ -238,7 +238,7 @@ def _helion_atomic_or_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE
238
238
239
239
def atomic_or_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
240
240
_BLOCK_SIZE_0 = 8
241
- _launcher(_helion_atomic_or_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
241
+ _launcher(_helion_atomic_or_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
242
242
return x
243
243
244
244
--- assertExpectedJournal(TestAtomicOperations.test_atomic_xchg)
@@ -260,7 +260,7 @@ def _helion_atomic_xchg_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SI
260
260
261
261
def atomic_xchg_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
262
262
_BLOCK_SIZE_0 = 8
263
- _launcher(_helion_atomic_xchg_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
263
+ _launcher(_helion_atomic_xchg_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
264
264
return x
265
265
266
266
--- assertExpectedJournal(TestAtomicOperations.test_atomic_xor)
@@ -282,7 +282,7 @@ def _helion_atomic_xor_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
282
282
283
283
def atomic_xor_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
284
284
_BLOCK_SIZE_0 = 8
285
- _launcher(_helion_atomic_xor_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
285
+ _launcher(_helion_atomic_xor_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
286
286
return x
287
287
288
288
--- assertExpectedJournal(TestAtomicOperations.test_basic_atomic_add)
@@ -305,7 +305,7 @@ def _helion_atomic_add_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZ
305
305
def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
306
306
"""Test basic atomic_add functionality."""
307
307
_BLOCK_SIZE_0 = 32
308
- _launcher(_helion_atomic_add_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
308
+ _launcher(_helion_atomic_add_kernel, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
309
309
return x
310
310
311
311
--- assertExpectedJournal(TestAtomicOperations.test_overlapping_atomic_add)
@@ -329,5 +329,5 @@ def _helion_atomic_add_overlap_kernel(indices, y, x, _BLOCK_SIZE_0: tl.constexpr
329
329
def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher):
330
330
"""Test atomic_add with overlapping indices."""
331
331
_BLOCK_SIZE_0 = 32
332
- _launcher(_helion_atomic_add_overlap_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=3 )
332
+ _launcher(_helion_atomic_add_overlap_kernel, (triton.cdiv(10, _BLOCK_SIZE_0),), indices, y, x, _BLOCK_SIZE_0, num_warps=4, num_stages=2 )
333
333
return x
0 commit comments