Skip to content

Commit 22bb27e

Browse files
committed
[mlir][amx] Optional stride for tile load and store
Adds an optional stride argument to the `amx.tile_load` and `amx.tile_store` operations. The stride argument aligns the ops more closely with the hardware intrinsics. However, stride remains optional to preserve the current op behavior. An explicit stride allows greater flexibility in terms of the base buffer shape (enabling usage of 1D memrefs) and allows different read and write memory patterns. When stride is not provided, it is inferred from the buffer shape. The documentation of both operations is expanded to make them easier to use.
1 parent 6b19ccd commit 22bb27e

File tree

5 files changed

+257
-84
lines changed

5 files changed

+257
-84
lines changed

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [
149149
let summary = "tile zero operation";
150150
let description = [{
151151
Zeroes the destination tile, with the shape defined by the 2-dim
152-
vector type of the result. This is eventually lowered into the
153-
"tilezero" instruction with the corresponding tile configuration.
154-
With memory-effects, each "tilezero" operation serves as a compilation
155-
hint to use a separate tile register.
152+
vector type of the result.
153+
154+
The operation is eventually lowered into the "tilezero" instruction
155+
with the corresponding tile configuration.
156+
157+
With the write memory effect, each `amx.tile_zero` operation serves as
158+
a compilation hint to use a separate tile register.
156159

157160
Example:
158161

@@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [
184187

185188
def TileLoadOp : AMX_Op<"tile_load", [
186189
AMXIntrinsicOpInterface,
187-
MemoryEffects<[MemWrite]>
190+
MemoryEffects<[MemWrite]>,
191+
AttrSizedOperandSegments
188192
]> {
189193
let summary = "tile load operation";
190194
let description = [{
191-
Loads a tile from memory defined by a base and indices, with the
192-
shape defined by the 2-dim vector type of the result. This is
193-
eventually lowered into the "tileloadd" instruction with the
194-
corresponding tile configuration. With memory-effects, each "tileload"
195-
operation serves as a compilation hint to use a separate tile register.
195+
Loads a tile from memory defined by a `base` and `indices`, with the
196+
shape defined by the 2-dim vector type of the result.
197+
The tile's rows are populated by reading contiguous elements starting
198+
at the `base`. For each tile row, the `base` is incremented by `stride`
199+
number of elements.
200+
201+
The tile is loaded using the following indexing scheme:
202+
203+
```
204+
for row in range(tile_rows):
205+
mem_row = base[i0, i1, ..., iN + row * stride]
206+
for col in range(tile_cols):
207+
tile[row, col] = mem_row[col]
208+
```
209+
210+
If the `stride` is not provided, then the `base` buffer must be at least
211+
2-dimensional, and the `stride` is automatically inferred and corresponds
212+
to the stride of the buffer's second innermost dimension.
213+
214+
The operation is eventually lowered into the "tileloadd" instruction
215+
with the corresponding tile configuration.
216+
217+
With the write memory effect, each `amx.tile_load` operation serves as
218+
a compilation hint to use a separate tile register.
196219

197220
Example:
198221

199222
```mlir
223+
// Tile load from a 2-D memref with implicit stride.
200224
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
225+
226+
// Tile load from a 1-D memref with explicit stride.
227+
%0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
201228
```
202229
}];
203230
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
204-
Variadic<Index>:$indices);
231+
Variadic<Index>:$indices,
232+
Optional<Index>:$stride);
205233
let results = (outs AnyAMXTile:$res);
234+
let builders = [
235+
OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
236+
];
206237
let extraClassDeclaration = [{
207238
MemRefType getMemRefType() {
208239
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [
219250
const ::mlir::LLVMTypeConverter &typeConverter,
220251
::mlir::RewriterBase &rewriter);
221252
}];
222-
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
223-
"type($base) `into` qualified(type($res))";
253+
let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
254+
"`:` type($base) `into` qualified(type($res))";
224255
let hasVerifier = 1;
225256
}
226257

227258
def TileStoreOp : AMX_Op<"tile_store", [
228-
AMXIntrinsicOpInterface
259+
AMXIntrinsicOpInterface,
260+
AttrSizedOperandSegments
229261
]> {
230262
let summary = "tile store operation";
231263
let description = [{
232-
Stores a tile to memory defined by a base and indices, with the
233-
shape defined by the 2-dim vector type of the value. This is
234-
eventually lowered into the "tilestored" instruction with the
235-
corresponding tile configuration.
264+
Stores a tile to memory defined by a `base` and `indices`, with the
265+
shape defined by the 2-dim vector type of the value.
266+
The tile's rows are written contiguously to the buffer starting at
267+
the `base`. For each tile row, the `base` is incremented by `stride`
268+
number of elements.
269+
270+
The tile is stored using the following indexing scheme:
271+
272+
```
273+
for row in range(tile_rows):
274+
mem_row = base[i0, i1, ..., iN + row * stride]
275+
for col in range(tile_cols):
276+
mem_row[col] = tile[row, col]
277+
```
278+
279+
If the `stride` is not provided, then the `base` buffer must be at least
280+
2-dimensional, and the `stride` is automatically inferred and corresponds
281+
to the stride of the buffer's second innermost dimension.
282+
283+
The operation is eventually lowered into the "tilestored" instruction
284+
with the corresponding tile configuration.
236285

237286
Example:
238287

239288
```mlir
289+
// Tile store to a 2-D memref with implicit stride.
240290
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
291+
292+
// Tile store to a 1-D memref with explicit stride.
293+
amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
241294
```
242295
}];
243296
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
244297
Variadic<Index>:$indices,
245-
AnyAMXTile:$val);
298+
AnyAMXTile:$val,
299+
Optional<Index>:$stride);
300+
let builders = [
301+
OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
302+
];
246303
let extraClassDeclaration = [{
247304
MemRefType getMemRefType() {
248305
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [
259316
const ::mlir::LLVMTypeConverter &typeConverter,
260317
::mlir::RewriterBase &rewriter);
261318
}];
262-
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
263-
"type($base) `,` qualified(type($val))";
319+
let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
320+
"attr-dict `:` type($base) `,` qualified(type($val))";
264321
let hasVerifier = 1;
265322
}
266323

@@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
276333
let description = [{
277334
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
278335
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
279-
pairs of "bf16"). The operation is eventually lowered into the
280-
"tdpbf16ps" instruction with the corresponding tile configuration.
336+
pairs of "bf16").
337+
338+
The operation is eventually lowered into the "tdpbf16ps" instruction with
339+
the corresponding tile configuration.
281340

282341
Example:
283342

@@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
330389
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
331390
combinations (4 bytes packed into dwords in the columns of both the
332391
source operand tiles; the zero or sign extension is specified with
333-
the attributes and default to sign extended). The operation is eventually
334-
lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
335-
instructions with the corresponding tile configuration.
392+
the attributes and defaults to sign extended).
393+
394+
The operation is eventually lowered into one of the "tdpbssd",
395+
"tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
396+
tile configuration.
336397

337398
Example:
338399

mlir/lib/Dialect/AMX/IR/AMXDialect.cpp

Lines changed: 63 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
8080
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
8181
}
8282

83+
/// Returns stride expressed in number of bytes for the given `elementStride`
84+
/// stride, which is encoded as a number of elements of the type `mType`.
85+
static Value computeStrideInBytes(Location loc, MemRefType mType,
86+
Value elementStride, RewriterBase &rewriter) {
87+
Type llvmInt64Type = rewriter.getIntegerType(64);
88+
unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
89+
auto attr = rewriter.getI64IntegerAttr(bytes);
90+
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
91+
return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
92+
.getResult();
93+
}
94+
8395
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
8496
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
85-
static Value getStride(Location loc, MemRefType mType, Value base,
86-
RewriterBase &rewriter) {
97+
static Value inferStride(Location loc, MemRefType mType, Value base,
98+
RewriterBase &rewriter) {
8799
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
88100
int64_t preLast = mType.getRank() - 2;
89101
Type llvmInt64Type = rewriter.getIntegerType(64);
@@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
94106
if (strides[preLast] == ShapedType::kDynamic) {
95107
// Dynamic stride needs code to compute the stride at runtime.
96108
MemRefDescriptor memrefDescriptor(base);
97-
auto attr = rewriter.getI64IntegerAttr(bytes);
98-
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
99-
return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
100-
memrefDescriptor.stride(rewriter, loc, preLast))
101-
.getResult();
109+
return computeStrideInBytes(
110+
loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
102111
}
103112
// Use direct constant for static stride.
104113
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
@@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
117126
return getTileSizes(getLoc(), getTileType(), rewriter);
118127
}
119128

120-
LogicalResult amx::TileLoadOp::verify() {
121-
MemRefType memrefTy = getMemRefType();
129+
template <typename OpTy,
130+
typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
131+
std::is_same_v<OpTy, amx::TileStoreOp>>>
132+
static LogicalResult tileTransferVerifier(OpTy op) {
133+
MemRefType memrefTy = op.getMemRefType();
122134
unsigned rank = memrefTy.getRank();
123-
if (rank < 2)
124-
return emitOpError("requires at least 2D memref");
125-
if (getIndices().size() != rank)
126-
return emitOpError("requires ") << rank << " indices";
127-
SmallVector<int64_t> strides;
128-
int64_t offset;
129-
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
130-
strides.back() != 1)
131-
return emitOpError("requires memref with unit innermost stride");
132-
return verifyTileSize(*this, getTileType());
135+
if (op.getIndices().size() != rank)
136+
return op.emitOpError("requires ") << rank << " indices";
137+
138+
if (failed(verifyTileSize(op, op.getTileType())))
139+
return failure();
140+
141+
// Validate basic buffer properties when the stride is implicit.
142+
if (!op.getStride()) {
143+
if (rank < 2)
144+
return op.emitOpError("requires at least 2D memref");
145+
SmallVector<int64_t> strides;
146+
int64_t offset;
147+
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
148+
strides.back() != 1)
149+
return op.emitOpError("requires memref with unit innermost stride");
150+
}
151+
152+
return success();
153+
}
154+
155+
void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
156+
Value base, ValueRange indices) {
157+
build(builder, state, res, base, indices, /*stride=*/nullptr);
133158
}
134159

160+
LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
161+
135162
SmallVector<Value>
136163
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
137164
const LLVMTypeConverter &typeConverter,
@@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
144171
intrinsicOperands.push_back(
145172
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
146173
adaptor.getBase(), adaptor.getIndices()));
147-
intrinsicOperands.push_back(
148-
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
174+
if (Value stride = adaptor.getStride())
175+
intrinsicOperands.push_back(
176+
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
177+
else
178+
intrinsicOperands.push_back(
179+
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
149180

150181
return intrinsicOperands;
151182
}
152183

153-
LogicalResult amx::TileStoreOp::verify() {
154-
MemRefType memrefTy = getMemRefType();
155-
unsigned rank = memrefTy.getRank();
156-
if (rank < 2)
157-
return emitOpError("requires at least 2D memref");
158-
if (getIndices().size() != rank)
159-
return emitOpError("requires ") << rank << " indices";
160-
SmallVector<int64_t> strides;
161-
int64_t offset;
162-
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
163-
strides.back() != 1)
164-
return emitOpError("requires memref with unit innermost stride");
165-
return verifyTileSize(*this, getTileType());
184+
void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
185+
Value base, ValueRange indices, Value val) {
186+
build(builder, state, base, indices, val, /*stride=*/nullptr);
166187
}
167188

189+
LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
190+
168191
SmallVector<Value>
169192
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
170193
const LLVMTypeConverter &typeConverter,
@@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
177200
intrinsicOperands.push_back(
178201
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
179202
adaptor.getBase(), adaptor.getIndices()));
180-
intrinsicOperands.push_back(
181-
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
203+
if (Value stride = adaptor.getStride())
204+
intrinsicOperands.push_back(
205+
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
206+
else
207+
intrinsicOperands.push_back(
208+
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
182209
intrinsicOperands.push_back(adaptor.getVal());
183210

184211
return intrinsicOperands;

0 commit comments

Comments
 (0)