
[MLIR] Legalize certain vector.transfer_read ops of scalable vectors #143146



Open: wants to merge 4 commits into base users/momchil-velikov/memref-contig-slice from users/momchil-velikov/legalize-scalable-transfer-read

Conversation

momchil-velikov (Collaborator)

This patch adds a transform of the transfer_read operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.
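For example (this mirrors the example in the new pattern's doc comment below), a read such as

%v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
  {in_bounds = [false, true, true, true]}
  : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>

is rewritten to

%collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
  : memref<?x?x2x8xi8> into memref<?x?xi8>
%0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
  {in_bounds = [false, true]}
  : memref<?x?xi8>, vector<2x[64]xi8>
%1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>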

@llvmbot (Member) commented Jun 6, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

This patch adds a transform of the transfer_read operation to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.


Full diff: https://github.com/llvm/llvm-project/pull/143146.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp (+109-1)
  • (added) mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir (+226)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir (+72)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
 
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read  %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension.  Nothing to do if the single scalable dimension is already the
+    // last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds (it
+    // follows these indices would be zero). This guarantees that the operation
+    // transfers a contiguous block.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(readOp,
+                                         "out-if-bounds index to collapse");
+
+    // Collapse the trailing dimensions of the memref.
+    SmallVector<ReassociationIndices> reassoc;
+    for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+      reassoc.push_back({i});
+    for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+         ++i)
+      reassoc.back().push_back(i);
+    if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+      return failure();
+    Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+        readOp.getLoc(), readOp.getBase(), reassoc);
+
+    // Get a vector type with collapsed trailing dimensions.
+    SmallVector<int64_t> shape(origVT.getShape());
+    for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+      shape[origVRank - numCollapseDims] *= shape[i];
+    shape.pop_back_n(numCollapseDims - 1);
+    auto collapsedVT =
+        VectorType::get(shape, origVT.getElementType(),
+                        origScalableDims.drop_back(numCollapseDims - 1));
+
+    // Drop the extra (zero) indices.
+    auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+    // Create the new `transfer_read`.
+    auto newReadOp = rewriter.create<vector::TransferReadOp>(
+        readOp.getLoc(), collapsedVT, collapsedMem, indices,
+        ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+    // Cast back to the original vector type.
+    auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+                                                            origVT, newReadOp);
+
+    rewriter.replaceOp(readOp, toOrigShape);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
@@ -306,7 +413,8 @@ void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
                LegalizeSVEMaskAllocation<memref::AllocaOp>,
                LegalizeSVEMaskAllocation<memref::AllocOp>,
                LegalizeSVEMaskTypeCastConversion,
-               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
+               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion,
+               LegalizeTransferRead>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
new file mode 100644
index 0000000000000..d12a2c11bbdba
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
@@ -0,0 +1,226 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL:       @test_base_case
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
+// CHECK:               %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1], [2, 3]]
+// CHECK-SAME:            : memref<?x?x?x8xi8> into memref<?x?x?xi8>
+// CHECK-NEXT:          %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:            : memref<?x?x?xi8>, vector<[32]xi8>
+// CHECK-NEXT:          %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:          return %[[T1]] : vector<[4]x8xi8>
+
+func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL:       @test_using_strided_layout
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:               %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1], [2, 3]]
+// CHECK-SAME:            : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
+// CHECK-SAME:              memref<?x?x?xi8, strided<[?, ?, 1]>>
+// CHECK-NEXT:          %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:            : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
+// CHECK-NEXT:          %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:          return %[[T1]] : vector<[4]x8xi8>
+
+#s0 = strided<[?, ?, 8, 1]>
+
+func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL:       @test_3d_vector
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:               %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2, 3]]
+// CHECK-SAME:            : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:              memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:          %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME:            : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
+// CHECK-NEXT:          %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
+// CHECK-NEXT:          return %[[T1]] : vector<[4]x2x8xi8>
+
+#s1 = strided<[?, 16, 8, 1]>
+
+func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>
+
+  return %A : vector<[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL:       @test_4d_vector
+// CHECK-SAME:          %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK:               %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2, 3]]
+// CHECK-SAME:           : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME:             memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT:         %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
+// CHECK-SAME:           : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
+// CHECK-NEXT:         %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+// CHECK-NEXT:         return %[[T1]] : vector<2x[4]x2x8xi8>
+
+#s2 = strided<[?, 16, 8, 1]>
+
+func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>
+
+  return %A : vector<2x[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_non_scalable
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>
+
+  return %A : vector<8x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_0
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>
+
+  return %A : vector<[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_1
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x[8]xi8>
+
+  return %A : vector<8x[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_type_not_supported
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]x[8]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>
+
+  return %A : vector<[8]x[8]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_non_mem
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : tensor<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_0
+// CHECK-NOT: memref.collapse
+
+#s3 = strided<[?, ?, 16, 1]>
+
+func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s3>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_1
+// CHECK-NOT: memref.collapse
+
+#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>
+
+func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #layout>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_read_strided_vec
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>
+
+  return %A : vector<[4]x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_bcast_transp
+// CHECK-NOT: memref.collapse
+
+#perm = affine_map<(i, j, k, p) -> (k, 0)>
+
+func.func @negative_test_bcast_transp(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {permutation_map = #perm, in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
new file mode 100644
index 0000000000000..7f68d8f7ab848
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
@@ -0,0 +1,72 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --arm-sve-legalize-vector-storage --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata    --lower-affine --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func private @printVec(%v : vector<[32]xi8>) {
+  %v0 = vector.scalable.extract %v[0] : vector<[16]xi8> from vector<[32]xi8>
+  %v1 = vector.scalable.extract %v[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %v0 : vector<[16]xi8>
+  vector.print %v1 : vector<[16]xi8>
+  return
+}
+
+func.func @transfer_read_scalable_not_rightmost(%vs : i32, %M : memref<?x?x?x8xi8>) {
+  func.call @setArmVLBits(%vs) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+  %A = vector.transfer_read %M[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  %B = vector.shape_cast %A : vector<[4]x8xi8> to vector<[32]xi8>
+  func.call @printVec(%B) : (vector<[32]xi8>) -> ()
+
+  return
+}
+
+func.func @main() {
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  %A0_cst = arith.constant dense<[[11, 12, 13, 14, 15, 16, 17, 18],
+                                  [21, 22, 23, 24, 25, 26, 27, 28],
+                                  [31, 32, 33, 34, 35, 36, 37, 38],
+                                  [41, 42, 43, 44, 45, 46, 47, 48]]> : vector<4x8xi8>
+
+  %A1_cst = arith.constant dense<[[51, 52, 53, 54, 55, 56, 57, 58],
+                                  [61, 62, 63, 64, 65, 66, 67, 68],
+                                  [71, 72, 73, 74, 75, 76, 77, 78],
+                                  [81, 82, 83, 84, 85, 86, 87, 88]]> : vector<4x8xi8>
+
+  %M = memref.alloca() : memref<1x2x4x8xi8>
+  vector.transfer_write %A0_cst, %M[%c0, %c0, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+  vector.transfer_write %A1_cst, %M[%c0, %c1, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+
+  %MM = memref.cast %M : memref<1x2x4x8xi8> to memref<?x?x?x8xi8>
+
+// CHECK:( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28 )
+// CHECK:( 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+  %c128 = arith.constant 128 : i32
+  func.call @transfer_read_scalable_not_rightmost(%c128, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+// CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+// CHECK: ( 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 85, 86, 87, 88 )
+  %c256 = arith.constant 256 : i32
+  func.call @transfer_read_scalable_not_rightmost(%c256, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+  return
+}


github-actions bot commented Jun 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from f60e73d to 1210d59 on June 6, 2025 15:27
momchil-velikov force-pushed the users/momchil-velikov/legalize-scalable-transfer-read branch 2 times, most recently from 4d13aa2 to 413d9dc on June 9, 2025 16:42
momchil-velikov force-pushed the users/momchil-velikov/memref-contig-slice branch from 1210d59 to 3b17c94 on June 9, 2025 16:42
momchil-velikov force-pushed the users/momchil-velikov/legalize-scalable-transfer-read branch from 413d9dc to 5496f97 on June 13, 2025 16:42
@banach-space (Contributor) left a comment:

Great work, Momchil - thank you!

I've left a number of comments, but nothing major. My main high-level suggestion is to follow the guidance in MLIR's Testing Guide a bit more closely. It’s a relatively new (and long!) document, so I’ve included specific in-line suggestions to make it easier to see where things could align better.

For additional context, this RFC provides some of the rationale behind that approach.

Also - what about memrefs with dynamic dimensions?

VectorType origVT = readOp.getVectorType();
ArrayRef<bool> origScalableDims = origVT.getScalableDims();
const int64_t origVRank = origVT.getRank();
if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
Contributor: [nit] getNumScalableDims would be more canonical than llvm::count.

Collaborator (Author): Done.

Comment on lines +342 to +352
if (!readOp.getPermutationMap().isMinorIdentity())
return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
Contributor:
Would supporting non-identity be a problem? It would be good to add a comment, either:

  • TODO: We haven't required this, so leaving for later. or
  • "Too complex because of , disabling".

Any hint for future developers would be helpful.
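For context, the negative test @negative_test_bcast_transp added in this patch shows the kind of non-minor-identity permutation (here a broadcast) that the pattern currently bails out on:

#perm = affine_map<(i, j, k, p) -> (k, 0)>
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
  {permutation_map = #perm, in_bounds = [true, true]}
  : memref<?x?x?x8xi8>, vector<[4]x8xi8>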

Collaborator (Author): Done.

Comment on lines 345 to 346
// We handle transfers of vectors with rank >= 2 and a single scalable
// dimension.
Contributor:
[nit] It would be helpful to add why:

  • Don't need to worry about 1D, that's supported by default.
  • More than one scalable dim is tricky (how would we collapse e.g. vscale * vscale?); see the example below.
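For reference, the negative test @negative_test_vector_type_not_supported in this patch already covers the multi-scalable-dim case, where the trailing dimensions cannot be collapsed into a single scalable one:

%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
  {in_bounds = [true, true, true]}
  : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>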

Collaborator (Author): Comment added.

Comment on lines 369 to 372
// The collapsed dimensions (excluding the scalable one) of the vector and
// the memref must match and the corresponding indices must be in-bounds (it
// follows these indices would be zero). This guarantees that the operation
// transfers a contiguous block.
Contributor:

// The collapsed dimensions (excluding the scalable one) of the vector and
// the memref must match

What about dynamic dim sizes in the memref? If that's not supported, is there a test?
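For illustration, a hypothetical example (not from the patch) with a dynamic trailing memref dimension; the shape-equality check above would reject it, since the dynamic size cannot be matched against the vector's static size 8:

// rejected: the memref's trailing dim is dynamic, the vector's is the static 8
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
  {in_bounds = [true, true]}
  : memref<?x?x?x?xi8>, vector<[4]x8xi8>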

Collaborator (Author): This part wasn't tested at all. Test cases added.

ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
[](bool v) { return v; }))
return rewriter.notifyMatchFailure(readOp,
"out-if-bounds index to collapse");
Contributor:
Note, it's not really the index that's out-of-bounds, but the corresponding memory access. So the index could be in-bounds, yet we might be reading "more" than there is available to read (starting at that index). For example:

vector.transfer_read %mem[5] : memref<7xi8>, vector<7xi8>
Suggested change:
- "out-if-bounds index to collapse");
+ "out-of-bounds index to collapse");

Collaborator (Author): Fixed.


#s3 = strided<[?, ?, 16, 1]>

func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
Contributor:

[nit] Avoid "magic" suffixes like _0.

Suggested change:
- func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
+ func.func @negative_test_discont_mem_due_to_strides(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {

Collaborator (Author): Done.


#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>

func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
Contributor:

[nit] Same as above.

Suggested change:
- func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
+ func.func @negative_test_discontig_mem_due_to_maps(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {

Collaborator (Author): Test removed, no need to test here all the possible ways a memref could be discontinuous.

Comment on lines 203 to 199
func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
%c0 = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8

%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>

return %A : vector<[4]x4xi8>
}
Contributor:

What makes this a negative test? It says "strided vec", but I'm not sure what you mean?

Collaborator (Author): That's garbage, deleted.

Comment on lines 233 to 255
func.func @negative_test_vector_mask(
%i : index, %j : index,
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

%c0 = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8

%A = vector.mask %mask {
vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
} : vector<[4]x8xi1> -> vector<[4]x8xi8>

return %A : vector<[4]x8xi8>
}

// -----

// CHECK-LABEL: @negative_test_mask_operand
// CHECK-NOT: memref.collapse

func.func @negative_test_mask_operand(
%i : index, %j : index,
%M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

%c0 = arith.constant 0 : index
%c0_i8 = arith.constant 0 : i8

%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>

return %A : vector<[4]x8xi8>
}
Contributor:

In the past, I would differentiate these as:

  • "masked" (vector.mask {vector. transfer_read}), vs
  • "with_mask" (vector.transfer_read %mask)

Would you mind following similar convention?

Collaborator (Author): Done.

Contributor:

This is mixing fixed-width and scalable vectors. Let's avoid that until we understand better how to mix VLA + VLS programming.

This patch adds a transform of the `transfer_read` operation to change the
vector type to one that can be mapped to an LLVM type. This is done by
collapsing trailing dimensions so we obtain a vector type with a single
scalable dimension in the rightmost position.
momchil-velikov force-pushed the users/momchil-velikov/legalize-scalable-transfer-read branch from 5496f97 to e422213 on June 20, 2025 11:54
3 participants