
Conversation

adam-smnk
Contributor

Adds an optional stride argument to amx.tile_load and amx.tile_store operations.

The stride argument aligns the ops more closely with the hardware intrinsics. However, the stride remains optional to preserve the current op behavior.

An explicit stride allows greater flexibility in the base buffer shape (for example, it enables the use of 1-D memrefs) and permits different read and write memory patterns.
When the stride is not provided, it is inferred from the buffer shape, as before.

Operations documentation is expanded to make ops easier to use.
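
For illustration, a minimal usage sketch of the two forms, based on the op examples in the patch below (buffer names, memref shapes, and the `%stride` value are placeholders):

```mlir
// Implicit stride: the base must be at least 2-D; the stride is inferred
// from the buffer's second innermost dimension.
%c0 = arith.constant 0 : index
%t0 = amx.tile_load %buf2d[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
amx.tile_store %buf2d[%c0, %c0], %t0 : memref<?x?xi8>, !amx.tile<16x64xi8>

// Explicit stride (in elements): allows 1-D base buffers and a custom row pitch.
%t1 = amx.tile_load %buf1d[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
amx.tile_store %buf1d[%c0], %t1, %stride : memref<?xi8>, !amx.tile<16x64xi8>
```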

@llvmbot
Member

llvmbot commented Sep 18, 2025

@llvm/pr-subscribers-mlir-amx
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds an optional stride argument to amx.tile_load and amx.tile_store operations.

The stride argument aligns the ops more closely with the hardware intrinsics. However, the stride remains optional to preserve the current op behavior.

An explicit stride allows greater flexibility in the base buffer shape (for example, it enables the use of 1-D memrefs) and permits different read and write memory patterns.
When the stride is not provided, it is inferred from the buffer shape, as before.

Operations documentation is expanded to make ops easier to use.


Patch is 24.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159569.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMX/AMX.td (+87-26)
  • (modified) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (+63-36)
  • (modified) mlir/test/Dialect/AMX/legalize-for-llvm.mlir (+66-22)
  • (modified) mlir/test/Dialect/AMX/roundtrip.mlir (+28)
  • (modified) mlir/test/Target/LLVMIR/amx.mlir (+13)
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 1236fede4d88b..cace63d32fd80 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [
   let summary = "tile zero operation";
   let description = [{
     Zeroes the destination tile, with the shape defined by the 2-dim
-    vector type of the result. This is eventually lowered into the
-    "tilezero" instruction with the corresponding tile configuration.
-    With memory-effects, each "tilezero" operation serves as a compilation 
-    hint to use a separate tile register.
+    vector type of the result.
+    
+    The operation is eventually lowered into the "tilezero" instruction
+    with the corresponding tile configuration.
+    
+    With the write memory effect, each `amx.tile_zero` operation serves as
+    a compilation hint to use a separate tile register.
 
     Example:
 
@@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [
 
 def TileLoadOp : AMX_Op<"tile_load", [
     AMXIntrinsicOpInterface,
-    MemoryEffects<[MemWrite]>
+    MemoryEffects<[MemWrite]>,
+    AttrSizedOperandSegments
   ]> {
   let summary = "tile load operation";
   let description = [{
-    Loads a tile from memory defined by a base and indices, with the
-    shape defined by the 2-dim vector type of the result. This is
-    eventually lowered into the "tileloadd" instruction with the
-    corresponding tile configuration. With memory-effects, each "tileload" 
-    operation serves as a compilation hint to use a separate tile register.
+    Loads a tile from memory defined by a `base` and `indices`, with the
+    shape defined by the 2-dim vector type of the result.
+    The tile's rows are populated by reading contiguous elements starting
+    at the `base`. For each tile row, the `base` is incremented by `stride`
+    number of elements.
+
+    The tile is loaded using the following indexing scheme:
+
+    ```
+    for row in enumerate(tile_rows):
+      mem_row = base[i0, i1, ..., iN + row * stride]
+      for col in enumerate(tile_cols):
+        tile[row, col] = mem_row[col]
+    ```
+
+    If the `stride` is not provided, then the `base` buffer must be at least
+    2-dimensional, and the `stride` is automatically inferred and corresponds
+    to the stride of the buffer's second innermost dimension.
+
+    The operation is eventually lowered into the "tileloadd" instruction
+    with the corresponding tile configuration.
+
+    With the write memory effect, each `amx.tile_load` operation serves as
+    a compilation hint to use a separate tile register.
 
     Example:
 
     ```mlir
+      // Tile load from a 2-D memref with implicit stride.
       %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
+
+      // Tile load from a 1-D memref with explicit stride.
+      %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
-                   Variadic<Index>:$indices);
+                   Variadic<Index>:$indices,
+                   Optional<Index>:$stride);
   let results = (outs AnyAMXTile:$res);
+  let builders = [
+    OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
+  ];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [
         const ::mlir::LLVMTypeConverter &typeConverter,
         ::mlir::RewriterBase &rewriter);
   }];
-  let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
-                       "type($base) `into` qualified(type($res))";
+  let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
+                       "`:` type($base) `into` qualified(type($res))";
   let hasVerifier = 1;
 }
 
 def TileStoreOp : AMX_Op<"tile_store", [
-    AMXIntrinsicOpInterface
+    AMXIntrinsicOpInterface,
+    AttrSizedOperandSegments
   ]> {
   let summary = "tile store operation";
   let description = [{
-    Stores a tile to memory defined by a base and indices, with the
-    shape defined by the 2-dim vector type of the value. This is
-    eventually lowered into the "tilestored" instruction with the
-    corresponding tile configuration.
+    Stores a tile to memory defined by a `base` and `indices`, with the
+    shape defined by the 2-dim vector type of the value.
+    The tile's rows are written contiguously to the buffer starting at
+    the `base`. For each tile row, the `base` is incremented by `stride`
+    number of elements.
+
+    The tile is stored using the following indexing scheme:
+
+    ```
+    for row in enumerate(tile_rows):
+      mem_row = base[i0, i1, ..., iN + row * stride]
+      for col in enumerate(tile_cols):
+        mem_row[col] = tile[row, col]
+    ```
+
+    If the `stride` is not provided, then the `base` buffer must be at least
+    2-dimensional, and the `stride` is automatically inferred and corresponds
+    to the stride of the buffer's second innermost dimension.
+
+    The operation is eventually lowered into the "tilestored" instruction
+    with the corresponding tile configuration.
 
     Example:
 
     ```mlir
+      // Tile store to a 2-D memref with implicit stride.
       amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
+
+      // Tile store to a 1-D memref with explicit stride.
+      amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   AnyAMXTile:$val);
+                   AnyAMXTile:$val,
+                   Optional<Index>:$stride);
+  let builders = [
+    OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
+  ];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [
         const ::mlir::LLVMTypeConverter &typeConverter,
         ::mlir::RewriterBase &rewriter);
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
-                       "type($base) `,` qualified(type($val))";
+  let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
+                       "attr-dict `:` type($base) `,` qualified(type($val))";
   let hasVerifier = 1;
 }
 
@@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
   let description = [{
     Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
     into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
-    pairs of "bf16"). The operation is eventually lowered into the
-    "tdpbf16ps" instruction with the corresponding tile configuration.
+    pairs of "bf16").
+    
+    The operation is eventually lowered into the "tdpbf16ps" instruction with
+    the corresponding tile configuration.
 
     Example:
 
@@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
     into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
     combinations (4 bytes packed into dwords in the columns of both the
     source operand tiles; the zero or sign extension is specified with
-    the attributes and default to sign extended). The operation is eventually
-    lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
-    instructions with the corresponding tile configuration.
+    the attributes and default to sign extended).
+    
+    The operation is eventually lowered into one of the "tdpbssd",
+    "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
+    tile configuration.
 
     Example:
 
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 68990ef0dc0c3..d9c097c9a3c6f 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
       LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
 }
 
+/// Returns stride expressed in number of bytes for the given `elementStride`
+/// stride encoded in number of elements of the type `mType`.
+static Value computeStrideInBytes(Location loc, MemRefType mType,
+                                  Value elementStride, RewriterBase &rewriter) {
+  Type llvmInt64Type = rewriter.getIntegerType(64);
+  unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
+  auto attr = rewriter.getI64IntegerAttr(bytes);
+  Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+  return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride)
+      .getResult();
+}
+
 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
 /// shape may "envelop" the actual tile shape, and may be dynamically sized.
-static Value getStride(Location loc, MemRefType mType, Value base,
-                       RewriterBase &rewriter) {
+static Value inferStride(Location loc, MemRefType mType, Value base,
+                         RewriterBase &rewriter) {
   assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
   int64_t preLast = mType.getRank() - 2;
   Type llvmInt64Type = rewriter.getIntegerType(64);
@@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
   if (strides[preLast] == ShapedType::kDynamic) {
     // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
-    auto attr = rewriter.getI64IntegerAttr(bytes);
-    Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
-    return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
-                               memrefDescriptor.stride(rewriter, loc, preLast))
-        .getResult();
+    return computeStrideInBytes(
+        loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
   }
   // Use direct constant for static stride.
   auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
@@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
   return getTileSizes(getLoc(), getTileType(), rewriter);
 }
 
-LogicalResult amx::TileLoadOp::verify() {
-  MemRefType memrefTy = getMemRefType();
+template <typename OpTy,
+          typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
+                                      std::is_same_v<OpTy, amx::TileStoreOp>>>
+static LogicalResult tileTransferVerifier(OpTy op) {
+  MemRefType memrefTy = op.getMemRefType();
   unsigned rank = memrefTy.getRank();
-  if (rank < 2)
-    return emitOpError("requires at least 2D memref");
-  if (getIndices().size() != rank)
-    return emitOpError("requires ") << rank << " indices";
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
-      strides.back() != 1)
-    return emitOpError("requires memref with unit innermost stride");
-  return verifyTileSize(*this, getTileType());
+  if (op.getIndices().size() != rank)
+    return op.emitOpError("requires ") << rank << " indices";
+
+  if (failed(verifyTileSize(op, op.getTileType())))
+    return failure();
+
+  // Validate basic buffer properties when the stride is implicit.
+  if (!op.getStride()) {
+    if (rank < 2)
+      return op.emitOpError("requires at least 2D memref");
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
+        strides.back() != 1)
+      return op.emitOpError("requires memref with unit innermost stride");
+  }
+
+  return success();
+}
+
+void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
+                            Value base, ValueRange indices) {
+  build(builder, state, res, base, indices, /*stride=*/nullptr);
 }
 
+LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); }
+
 SmallVector<Value>
 amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
                                       const LLVMTypeConverter &typeConverter,
@@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
   intrinsicOperands.push_back(
       LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
                                  adaptor.getBase(), adaptor.getIndices()));
-  intrinsicOperands.push_back(
-      getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+  if (Value stride = adaptor.getStride())
+    intrinsicOperands.push_back(
+        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+  else
+    intrinsicOperands.push_back(
+        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
 
   return intrinsicOperands;
 }
 
-LogicalResult amx::TileStoreOp::verify() {
-  MemRefType memrefTy = getMemRefType();
-  unsigned rank = memrefTy.getRank();
-  if (rank < 2)
-    return emitOpError("requires at least 2D memref");
-  if (getIndices().size() != rank)
-    return emitOpError("requires ") << rank << " indices";
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
-      strides.back() != 1)
-    return emitOpError("requires memref with unit innermost stride");
-  return verifyTileSize(*this, getTileType());
+void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
+                             Value base, ValueRange indices, Value val) {
+  build(builder, state, base, indices, val, /*stride=*/nullptr);
 }
 
+LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); }
+
 SmallVector<Value>
 amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
                                        const LLVMTypeConverter &typeConverter,
@@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
   intrinsicOperands.push_back(
       LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
                                  adaptor.getBase(), adaptor.getIndices()));
-  intrinsicOperands.push_back(
-      getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
+  if (Value stride = adaptor.getStride())
+    intrinsicOperands.push_back(
+        computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
+  else
+    intrinsicOperands.push_back(
+        inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
   intrinsicOperands.push_back(adaptor.getVal());
 
   return intrinsicOperands;
diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
index 7e562b00a46a9..a109f42e9dea3 100644
--- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir
@@ -60,30 +60,74 @@ func.func @mulfp16(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
   return
 }
 
-// CHECK-LABEL: strides(
-// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]]
-// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]]
-// CHECK: llvm.mlir.constant(2 : i64) : i64
+/// Intrinsics require stride in number of bytes.
+// CHECK-LABEL: strides_implicit(
+// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]]
+// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]]
 // CHECK: llvm.extractvalue %{{.+}}[4, 0]
-// CHECK: %[[STRIDE_1:.+]] = llvm.mul
-// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]]
-// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]]
-// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64
-// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]]
-// CHECK: llvm.mlir.constant(2 : i64) : i64
+// CHECK: %[[LOAD_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[LOAD_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE]], %[[LOAD_BUF_STRIDE]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]]
+// CHECK: %[[STORE_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_1]]
+// CHECK: %[[STORE_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_2]]
 // CHECK: llvm.extractvalue %{{.+}}[4, 0]
-// CHECK: %[[STRIDE_2:.+]] = llvm.mul
-// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]]
-func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) {
+// CHECK: %[[STORE_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0]
+// CHECK: %[[STORE_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK: %[[STORE_STRIDE_3:.+]] = llvm.mul %[[STORE_STRIDE_SCALE]], %[[STORE_BUF_STRIDE]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_3]]
+func.func @strides_implicit(%arg0: memref<16x32xi8>,
+    %arg1: memref<32x32xbf16, strided<[64, 1]>>,
+    %arg2: memref<16x32xf32, strided<[?, 1]>>) {
   %0 = arith.constant 0 : index
-  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16>
-  %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
-  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16>
-  amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16>
-  amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
-  amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16>
+  %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8>
+  %2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16>
+  %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32>
+  amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8>
+  amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16>
+  amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32>
+  return
+}
+
+/// Intrinsics require stride in number of bytes.
+// CHECK-LABEL: strides_explicit(
+// CHECK-SAME:    %[[STRIDE:.+]]: index
+// CHECK-DAG: %[[STRIDE_I64:.+]] = builtin.unrealized_conversion_cast %[[STRIDE]] : index to i64
+// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG: %[[C64_I64:.+]] = builtin.unrealized_conversion_cast %[[C64]] : index to i64
+// CHECK: %[[LOAD_STRIDE_SCALE_1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_1]], %[[STRIDE_I64]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]]
+// CHECK: %[[LOAD_STRIDE_SCALE_2:.+]] = llvm.mlir.constant(2 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_2]], %[[STRIDE_I64]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]]
+// CHECK: %[[LOAD_STRIDE_SCALE_3:.+]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_3]], %[[C64_I64]]
+// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]]
+// CHECK: %[[STORE_STR...
[truncated]

@adam-smnk
Contributor Author

cc: @arun-thmn

Member

@rengolin left a comment

Missing tests for tile_zero. I know they're the same, but the semantics do change and it's good to make sure they remain consistent.

@adam-smnk
Contributor Author

Missing tests for tile_zero. I know they're the same, but the semantics do change and it's good to make sure they remain consistent.

Nothing changes for tile_zero here. It has no stride semantics as it doesn't interact with memory.
I only tweaked its docs formatting.

Happy to add coverage for anything else if you spotted any gaps.

@rengolin
Member

My bad, you're right!

Adds an optional stride argument to `amx.tile_load` and `amx.tile_store`
operations.

The stride argument aligns ops closer to the hardware intrinsics.
However, stride remains optional to preserve current op behavior.

Explicit stride allows greater flexibility in terms of the base buffer
shape (enables usage of 1D memrefs) and allows different read and write
memory patterns.
When stride is not provided, it is inferred from the buffer shape.

Operations documentation is expanded to make ops easier to use.
@adam-smnk
Contributor Author

Rebased on main + bump for review
