[mlir][amx] Optional stride for tile load and store #159569
Conversation
@llvm/pr-subscribers-mlir-amx @llvm/pr-subscribers-mlir-llvm

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds an optional stride argument to the `amx.tile_load` and `amx.tile_store` operations. The stride argument aligns the ops more closely with the hardware intrinsics; however, the stride remains optional to preserve the current op behavior. An explicit stride allows greater flexibility in the shape of the base buffer and allows different read and write memory patterns. The operations' documentation is expanded to make the ops easier to use.

Patch is 24.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159569.diff

5 Files Affected:
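As a quick orientation before the diff, here is a minimal sketch of the two assembly forms this patch supports (adapted from the examples in `AMX.td`; SSA value names such as `%flat` and `%stride` are illustrative):

```mlir
// Implicit stride: the base must be at least 2-D and the stride is inferred
// from the buffer's second-innermost dimension.
%t0 = amx.tile_load %src[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
amx.tile_store %dst[%c0, %c0], %t0 : memref<?x?xi8>, !amx.tile<16x64xi8>

// Explicit stride (given in elements): the base may now be a 1-D memref.
%t1 = amx.tile_load %flat[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
amx.tile_store %flat[%c0], %t1, %stride : memref<?xi8>, !amx.tile<16x64xi8>
```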
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 1236fede4d88b..cace63d32fd80 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [
let summary = "tile zero operation";
let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
- vector type of the result. This is eventually lowered into the
- "tilezero" instruction with the corresponding tile configuration.
- With memory-effects, each "tilezero" operation serves as a compilation
- hint to use a separate tile register.
+ vector type of the result.
+
+ The operation is eventually lowered into the "tilezero" instruction
+ with the corresponding tile configuration.
+
+ With the write memory effect, each `amx.tile_zero` operation serves as
+ a compilation hint to use a separate tile register.
Example:
@@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [
def TileLoadOp : AMX_Op<"tile_load", [
AMXIntrinsicOpInterface,
- MemoryEffects<[MemWrite]>
+ MemoryEffects<[MemWrite]>,
+ AttrSizedOperandSegments
]> {
let summary = "tile load operation";
let description = [{
- Loads a tile from memory defined by a base and indices, with the
- shape defined by the 2-dim vector type of the result. This is
- eventually lowered into the "tileloadd" instruction with the
- corresponding tile configuration. With memory-effects, each "tileload"
- operation serves as a compilation hint to use a separate tile register.
+ Loads a tile from memory defined by a `base` and `indices`, with the
+ shape defined by the 2-dim vector type of the result.
+ The tile's rows are populated by reading contiguous elements starting
+ at the `base`. For each tile row, the `base` is incremented by `stride`
+ number of elements.
+
+ The tile is loaded using the following indexing scheme:
+
+ ```
+ for row in enumerate(tile_rows):
+ mem_row = base[i0, i1, ..., iN + row * stride]
+ for col in enumerate(tile_cols):
+ tile[row, col] = mem_row[col]
+ ```
+
+ If the `stride` is not provided, then the `base` buffer must be at least
+ 2-dimensional, and the `stride` is automatically inferred and corresponds
+ to the stride of the buffer's second innermost dimension.
+
+ The operation is eventually lowered into the "tileloadd" instruction
+ with the corresponding tile configuration.
+
+ With the write memory effect, each `amx.tile_load` operation serves as
+ a compilation hint to use a separate tile register.
Example:
```mlir
+ // Tile load from a 2-D memref with implicit stride.
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+
+ // Tile load from a 1-D memref with explicit stride.
+ %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
- Variadic<Index>:$indices);
+ Variadic<Index>:$indices,
+ Optional<Index>:$stride);
let results = (outs AnyAMXTile:$res);
+ let builders = [
+ OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
+ ];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
- let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
- "type($base) `into` qualified(type($res))";
+ let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
+ "`:` type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}
def TileStoreOp : AMX_Op<"tile_store", [
- AMXIntrinsicOpInterface
+ AMXIntrinsicOpInterface,
+ AttrSizedOperandSegments
]> {
let summary = "tile store operation";
let description = [{
- Stores a tile to memory defined by a base and indices, with the
- shape defined by the 2-dim vector type of the value. This is
- eventually lowered into the "tilestored" instruction with the
- corresponding tile configuration.
+ Stores a tile to memory defined by a `base` and `indices`, with the
+ shape defined by the 2-dim vector type of the value.
+ The tile's rows are written contiguously to the buffer starting at
+ the `base`. For each tile row, the `base` is incremented by `stride`
+ number of elements.
+
+ The tile is stored using the following indexing scheme:
+
+ ```
+ for row in enumerate(tile_rows):
+ mem_row = base[i0, i1, ..., iN + row * stride]
+ for col in enumerate(tile_cols):
+ mem_row[col] = tile[row, col]
+ ```
+
+ If the `stride` is not provided, then the `base` buffer must be at least
+ 2-dimensional, and the `stride` is automatically inferred and corresponds
+ to the stride of the buffer's second innermost dimension.
+
+ The operation is eventually lowered into the "tilestored" instruction
+ with the corresponding tile configuration.
Example:
```mlir
+ // Tile store to a 2-D memref with implicit stride.
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
+
+ // Tile store to a 1-D memref with explicit stride.
+ amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
- AnyAMXTile:$val);
+ AnyAMXTile:$val,
+ Optional<Index>:$stride);
+ let builders = [
+ OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
+ ];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
- let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
- "type($base) `,` qualified(type($val))";
+ let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
+ "attr-dict `:` type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}
@@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
- pairs of "bf16"). The operation is eventually lowered into the
- "tdpbf16ps" instruction with the corresponding tile configuration.
+ pairs of "bf16").
+
+ The operation is eventually lowered into the "tdpbf16ps" instruction with
+ the corresponding tile configuration.
Example:
@@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
combinations (4 bytes packed into dwords in the columns of both the
source operand tiles; the zero or sign extension is specified with
- the attributes and default to sign extended). The operation is eventually
- lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
- instructions with the corresponding tile configuration.
+ the attributes and default to sign extended).
+
+ The operation is eventually lowered into one of the "tdpbssd",
+ "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
+ tile configuration.
Example:
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 68990ef0dc0c3..d9c097c9a3c6f 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}
+/// Returns stride expressed in number of bytes for the given `elementStride`
+/// stride encoded in number of elements of the type `mType`.
+static Value computeStrideInBytes(Location loc, MemRefType mType,
+ Value elementStride, RewriterBase &rewriter) {
+ Type llvmInt64Type = rewriter.getIntegerType(64);
+ unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
+ auto attr = rewriter.getI64IntegerAttr(bytes);
+ Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+ return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
+ .getResult();
+}
+
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
-static Value getStride(Location loc, MemRefType mType, Value base,
- RewriterBase &rewriter) {
+static Value inferStride(Location loc, MemRefType mType, Value base,
+ RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
@@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
- auto attr = rewriter.getI64IntegerAttr(bytes);
- Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
- return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
- memrefDescriptor.stride(rewriter, loc, preLast))
- .getResult();
+ return computeStrideInBytes(
+ loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
@@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
return getTileSizes(getLoc(), getTileType(), rewriter);
}
-LogicalResult amx::TileLoadOp::verify() {
- MemRefType memrefTy = getMemRefType();
+template <typename OpTy,
+ typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
+ std::is_same_v<OpTy, amx::TileStoreOp>>>
+static LogicalResult tileTransferVerifier(OpTy op) {
+ MemRefType memrefTy = op.getMemRefType();
unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+ if (op.getIndices().size() != rank)
+ return op.emitOpError("requires ") << rank << " indices";
+
+ if (failed(verifyTileSize(op, op.getTileType())))
+ return failure();
+
+ // Validate basic buffer properties when the stride is implicit.
+ if (!op.getStride()) {
+ if (rank < 2)
+ return op.emitOpError("requires at least 2D memref");
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+ strides.back() != 1)
+ return op.emitOpError("requires memref with unit innermost stride");
+ }
+
+ return success();
+}
+
+void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
+ Value base, ValueRange indices) {
+ build(builder, state, res, base, indices, /*stride=*/nullptr);
}
+LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
return intrinsicOperands;
}
-LogicalResult amx::TileStoreOp::verify() {
- MemRefType memrefTy = getMemRefType();
- unsigned rank = memrefTy.getRank();
- if (rank < 2)
- return emitOpError("requires at least 2D memref");
- if (getIndices().size() != rank)
- return emitOpError("requires ") << rank << " indices";
- SmallVector<int64_t> strides;
- int64_t offset;
- if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
- strides.back() != 1)
- return emitOpError("requires memref with unit innermost stride");
- return verifyTileSize(*this, getTileType());
+void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
+ Value base, ValueRange indices, Value val) {
+ build(builder, state, base, indices, val, /*stride=*/nullptr);
}
+LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
+
SmallVector<Value>
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
@@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
- intrinsicOperands.push_back(
- getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+ if (Value stride = adaptor.getStride())
+ intrinsicOperands.push_back(
+ computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+ else
+ intrinsicOperands.push_back(
+ inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());
return intrinsicOperands;
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 7e562b00a46a9..a109f42e9dea3 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -60,30 +60,74 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
return
}
-// CHECK-LABEL: strides(
-// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
-// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
-// CHECK: llvm.mlir.constant(2 : i64) : i64
+/// Intrinsics require stride in number of bytes.
+// CHECK-LABEL: strides_implicit(
+// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]]
+// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]]
// CHECK: llvm.extractvalue %{{.+}}[4, 0]
-// CHECK: %[[STRIDE_1:.+]] = llvm.mul
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
-// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
-// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
-// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: %[[LOAD_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[LOAD_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE]], %[[LOAD_BUF_STRIDE]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]]
+// CHECK: %[[STORE_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_1]]
+// CHECK: %[[STORE_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_2]]
// CHECK: llvm.extractvalue %{{.+}}[4, 0]
-// CHECK: %[[STRIDE_2:.+]] = llvm.mul
-// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
-func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+// CHECK: %[[STORE_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STORE_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK: %[[STORE_STRIDE_3:.+]] = llvm.mul %[[STORE_STRIDE_SCALE]], %[[STORE_BUF_STRIDE]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_3]]
+func.func @strides_implicit(%arg0: memref<16x32xi8>,
+ %arg1: memref<32x32xbf16, strided<[64, 1]>>,
+ %arg2: memref<16x32xf32, strided<[?, 1]>>) {
%0 = arith.constant 0 : index
- %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
- %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
- %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16>
- amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16>
- amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
- amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16>
+ %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8>
+ %2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
+ %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32>
+ amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8>
+ amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
+ amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32>
+ return
+}
+
+/// Intrinsics require stride in number of bytes.
+// CHECK-LABEL: strides_explicit(
+// CHECK-SAME: %[[STRIDE:.+]]: index
+// CHECK-DAG: %[[STRIDE_I64:.+]] = builtin.unrealized_conversion_cast %[[STRIDE]] : index to i64
+// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG: %[[C64_I64:.+]] = builtin.unrealized_conversion_cast %[[C64]] : index to i64
+// CHECK: %[[LOAD_STRIDE_SCALE_1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_1]], %[[STRIDE_I64]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]]
+// CHECK: %[[LOAD_STRIDE_SCALE_2:.+]] = llvm.mlir.constant(2 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_2]], %[[STRIDE_I64]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]]
+// CHECK: %[[LOAD_STRIDE_SCALE_3:.+]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_3]], %[[C64_I64]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]]
+// CHECK: %[[STORE_STR...
[truncated]
cc: @arun-thmn
missing tests for `tile_zero`. I know they're the same, but semantics does change and it's good to make sure it remains consistent.
Nothing changes for `tile_zero`. Happy to add coverage for anything else if you spotted any gaps.
My bad, you're right!
Adds an optional stride argument to `amx.tile_load` and `amx.tile_store` operations. The stride argument aligns ops closer to the hardware intrinsics. However, stride remains optional to preserve current op behavior. Explicit stride allows greater flexibility in terms of the base buffer shape (enables usage of 1D memrefs) and allows different read and write memory patterns. When stride is not provided, it is inferred from the buffer shape. Operations documentation is expanded to make ops easier to use.
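One concrete reading of the "different read and write memory patterns" point: with explicit strides, a tile can be read with one row pitch and written back with another. A minimal sketch under that assumption (function and buffer names are illustrative, not from the patch):

```mlir
// Read a 16x64 i8 tile out of a flat buffer whose rows are %src_stride
// elements apart, then store it densely packed (rows 64 elements apart).
func.func @repack_rows(%src: memref<?xi8>, %dst: memref<?xi8>, %src_stride: index) {
  %c0 = arith.constant 0 : index
  %c64 = arith.constant 64 : index
  %t = amx.tile_load %src[%c0], %src_stride : memref<?xi8> into !amx.tile<16x64xi8>
  amx.tile_store %dst[%c0], %t, %c64 : memref<?xi8>, !amx.tile<16x64xi8>
  return
}
```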
force-pushed from d7e3f13 to 22bb27e
Rebased on main + bump for review
Adds an optional stride argument to the `amx.tile_load` and `amx.tile_store` operations.

The stride argument aligns ops closer to the hardware intrinsics. However, stride remains optional to preserve current op behavior.

Explicit stride allows greater flexibility in terms of the base buffer shapes and allows different read and write memory patterns.

When stride is not provided, it is inferred from the buffer shape as before.

Operations documentation is expanded to make ops easier to use.
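Note that the op-level stride is expressed in elements, while the x86 intrinsics expect it in bytes; the lowering scales the value by the element size (see `computeStrideInBytes` in the patch). A rough sketch in the style of the `legalize-for-llvm.mlir` test above, assuming a bf16 buffer (scale factor 2) and illustrative names:

```mlir
// CHECK-LABEL: explicit_stride_bf16(
// CHECK: %[[SCALE:.+]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[BYTES:.+]] = llvm.mul %[[SCALE]], %{{.+}}
// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[BYTES]]
func.func @explicit_stride_bf16(%buf: memref<?xbf16>, %stride: index) {
  %c0 = arith.constant 0 : index
  %t = amx.tile_load %buf[%c0], %stride : memref<?xbf16> into !amx.tile<16x32xbf16>
  amx.tile_store %buf[%c0], %t, %stride : memref<?xbf16>, !amx.tile<16x32xbf16>
  return
}
```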