@@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [
149
149
let summary = "tile zero operation";
150
150
let description = [{
151
151
Zeroes the destination tile, with the shape defined by the 2-dim
152
- vector type of the result. This is eventually lowered into the
153
- "tilezero" instruction with the corresponding tile configuration.
154
- With memory-effects, each "tilezero" operation serves as a compilation
155
- hint to use a separate tile register.
152
+ vector type of the result.
153
+
154
+ The operation is eventually lowered into the "tilezero" instruction
155
+ with the corresponding tile configuration.
156
+
157
+ With the write memory effect, each `amx.tile_zero` operation serves as
158
+ a compilation hint to use a separate tile register.
156
159
157
160
Example:
158
161
@@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [
184
187
185
188
def TileLoadOp : AMX_Op<"tile_load", [
186
189
AMXIntrinsicOpInterface,
187
- MemoryEffects<[MemWrite]>
190
+ MemoryEffects<[MemWrite]>,
191
+ AttrSizedOperandSegments
188
192
]> {
189
193
let summary = "tile load operation";
190
194
let description = [{
191
- Loads a tile from memory defined by a base and indices, with the
192
- shape defined by the 2-dim vector type of the result. This is
193
- eventually lowered into the "tileloadd" instruction with the
194
- corresponding tile configuration. With memory-effects, each "tileload"
195
- operation serves as a compilation hint to use a separate tile register.
195
+ Loads a tile from memory defined by a `base` and `indices`, with the
196
+ shape defined by the 2-dim vector type of the result.
197
+ The tile's rows are populated by reading contiguous elements starting
198
+ at the `base`. For each tile row, the `base` is incremented by `stride`
199
+ number of elements.
200
+
201
+ The tile is loaded using the following indexing scheme:
202
+
203
+ ```
204
+ for row in enumerate(tile_rows):
205
+ mem_row = base[i0, i1, ..., iN + row * stride]
206
+ for col in enumerate(tile_cols):
207
+ tile[row, col] = mem_row[col]
208
+ ```
209
+
210
+ If the `stride` is not provided, then the `base` buffer must be at least
211
+ 2-dimensional, and the `stride` is automatically inferred and corresponds
212
+ to the stride of the buffer's second innermost dimension.
213
+
214
+ The operation is eventually lowered into the "tileloadd" instruction
215
+ with the corresponding tile configuration.
216
+
217
+ With the write memory effect, each `amx.tile_load` operation serves as
218
+ a compilation hint to use a separate tile register.
196
219
197
220
Example:
198
221
199
222
```mlir
223
+ // Tile load from a 2-D memref with implicit stride.
200
224
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
225
+
226
+ // Tile load from a 1-D memref with explicit stride.
227
+ %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
201
228
```
202
229
}];
203
230
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
204
- Variadic<Index>:$indices);
231
+ Variadic<Index>:$indices,
232
+ Optional<Index>:$stride);
205
233
let results = (outs AnyAMXTile:$res);
234
+ let builders = [
235
+ OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
236
+ ];
206
237
let extraClassDeclaration = [{
207
238
MemRefType getMemRefType() {
208
239
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [
219
250
const ::mlir::LLVMTypeConverter &typeConverter,
220
251
::mlir::RewriterBase &rewriter);
221
252
}];
222
- let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
223
- "type($base) `into` qualified(type($res))";
253
+ let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict "
254
+ "`:` type($base) `into` qualified(type($res))";
224
255
let hasVerifier = 1;
225
256
}
226
257
227
258
def TileStoreOp : AMX_Op<"tile_store", [
228
- AMXIntrinsicOpInterface
259
+ AMXIntrinsicOpInterface,
260
+ AttrSizedOperandSegments
229
261
]> {
230
262
let summary = "tile store operation";
231
263
let description = [{
232
- Stores a tile to memory defined by a base and indices, with the
233
- shape defined by the 2-dim vector type of the value. This is
234
- eventually lowered into the "tilestored" instruction with the
235
- corresponding tile configuration.
264
+ Stores a tile to memory defined by a `base` and `indices`, with the
265
+ shape defined by the 2-dim vector type of the value.
266
+ The tile's rows are written contiguously to the buffer starting at
267
+ the `base`. For each tile row, the `base` is incremented by `stride`
268
+ number of elements.
269
+
270
+ The tile is stored using the following indexing scheme:
271
+
272
+ ```
273
+ for row in enumerate(tile_rows):
274
+ mem_row = base[i0, i1, ..., iN + row * stride]
275
+ for col in enumerate(tile_cols):
276
+ mem_row[col] = tile[row, col]
277
+ ```
278
+
279
+ If the `stride` is not provided, then the `base` buffer must be at least
280
+ 2-dimensional, and the `stride` is automatically inferred and corresponds
281
+ to the stride of the buffer's second innermost dimension.
282
+
283
+ The operation is eventually lowered into the "tilestored" instruction
284
+ with the corresponding tile configuration.
236
285
237
286
Example:
238
287
239
288
```mlir
289
+ // Tile store to a 2-D memref with implicit stride.
240
290
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
291
+
292
+ // Tile store to a 1-D memref with explicit stride.
293
+ amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
241
294
```
242
295
}];
243
296
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
244
297
Variadic<Index>:$indices,
245
- AnyAMXTile:$val);
298
+ AnyAMXTile:$val,
299
+ Optional<Index>:$stride);
300
+ let builders = [
301
+ OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
302
+ ];
246
303
let extraClassDeclaration = [{
247
304
MemRefType getMemRefType() {
248
305
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [
259
316
const ::mlir::LLVMTypeConverter &typeConverter,
260
317
::mlir::RewriterBase &rewriter);
261
318
}];
262
- let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
263
- "type($base) `,` qualified(type($val))";
319
+ let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )? "
320
+ "attr-dict `:` type($base) `,` qualified(type($val))";
264
321
let hasVerifier = 1;
265
322
}
266
323
@@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
276
333
let description = [{
277
334
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
278
335
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
279
- pairs of "bf16"). The operation is eventually lowered into the
280
- "tdpbf16ps" instruction with the corresponding tile configuration.
336
+ pairs of "bf16").
337
+
338
+ The operation is eventually lowered into the "tdpbf16ps" instruction with
339
+ the corresponding tile configuration.
281
340
282
341
Example:
283
342
@@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
330
389
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
331
390
combinations (4 bytes packed into dwords in the columns of both the
332
391
source operand tiles; the zero or sign extension is specified with
333
- the attributes and default to sign extended). The operation is eventually
334
- lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
335
- instructions with the corresponding tile configuration.
392
+ the attributes and default to sign extended).
393
+
394
+ The operation is eventually lowered into one of the "tdpbssd",
395
+ "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
396
+ tile configuration.
336
397
337
398
Example:
338
399
0 commit comments