Skip to content

Conversation

@jtuyls
Copy link
Contributor

@jtuyls jtuyls commented Jan 5, 2026

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2026

@llvm/pr-subscribers-mlir

Author: Jorn Tuyls (jtuyls)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/174477.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+3-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+11)
  • (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+35)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..45122788bd2d4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -149,7 +149,9 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       Pure,
       ViewLikeOpInterface,
       SameOperandsAndResultType,
-      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+                                ["reifyDimOfResult"]>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7bc6ae5f21e8b..24089f4370c8a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -606,6 +606,17 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
 }
 
+FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
+                                                            int resultIndex,
+                                                            int dim) {
+  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
+  Value source = getMemref();
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (sourceType.isDynamicDim(dim))
+    return OpFoldResult(builder.createOrFold<DimOp>(getLoc(), source, dim));
+  return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
+}
+
 //===----------------------------------------------------------------------===//
 // DistinctObjectsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index e354eb91d7557..374e47fb34b48 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -97,3 +97,38 @@ func.func @iter_to_init_arg_loop_like(
   }
   return %result : tensor<?x?xf32>
 }
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with static dims
+// CHECK-LABEL: func @dim_of_assume_alignment_static(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<2x3xf32>
+//  CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+//  CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:   return %[[C2]], %[[C3]] : index, index
+func.func @dim_of_assume_alignment_static(%arg0: memref<2x3xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<2x3xf32>
+  %d0 = memref.dim %0, %c0 : memref<2x3xf32>
+  %d1 = memref.dim %0, %c1 : memref<2x3xf32>
+  return %d0, %d1 : index, index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with dynamic dims
+// CHECK-LABEL: func @dim_of_assume_alignment_dynamic(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<4x?xf32>
+//  CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[D1:.*]] = memref.dim %[[MEM]], %[[C1]]
+//       CHECK:   return %[[C4]], %[[D1]] : index, index
+func.func @dim_of_assume_alignment_dynamic(%arg0: memref<4x?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<4x?xf32>
+  %d0 = memref.dim %0, %c0 : memref<4x?xf32>
+  %d1 = memref.dim %0, %c1 : memref<4x?xf32>
+  return %d0, %d1 : index, index
+}

@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2026

@llvm/pr-subscribers-mlir-memref

Author: Jorn Tuyls (jtuyls)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/174477.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+3-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+11)
  • (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+35)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..45122788bd2d4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -149,7 +149,9 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       Pure,
       ViewLikeOpInterface,
       SameOperandsAndResultType,
-      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+                                ["reifyDimOfResult"]>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7bc6ae5f21e8b..24089f4370c8a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -606,6 +606,17 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
 }
 
+FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
+                                                            int resultIndex,
+                                                            int dim) {
+  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
+  Value source = getMemref();
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (sourceType.isDynamicDim(dim))
+    return OpFoldResult(builder.createOrFold<DimOp>(getLoc(), source, dim));
+  return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
+}
+
 //===----------------------------------------------------------------------===//
 // DistinctObjectsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index e354eb91d7557..374e47fb34b48 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -97,3 +97,38 @@ func.func @iter_to_init_arg_loop_like(
   }
   return %result : tensor<?x?xf32>
 }
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with static dims
+// CHECK-LABEL: func @dim_of_assume_alignment_static(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<2x3xf32>
+//  CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+//  CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:   return %[[C2]], %[[C3]] : index, index
+func.func @dim_of_assume_alignment_static(%arg0: memref<2x3xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<2x3xf32>
+  %d0 = memref.dim %0, %c0 : memref<2x3xf32>
+  %d1 = memref.dim %0, %c1 : memref<2x3xf32>
+  return %d0, %d1 : index, index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with dynamic dims
+// CHECK-LABEL: func @dim_of_assume_alignment_dynamic(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<4x?xf32>
+//  CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[D1:.*]] = memref.dim %[[MEM]], %[[C1]]
+//       CHECK:   return %[[C4]], %[[D1]] : index, index
+func.func @dim_of_assume_alignment_dynamic(%arg0: memref<4x?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<4x?xf32>
+  %d0 = memref.dim %0, %c0 : memref<4x?xf32>
+  %d1 = memref.dim %0, %c1 : memref<4x?xf32>
+  return %d0, %d1 : index, index
+}

kuhar pushed a commit that referenced this pull request Jan 6, 2026
…174548)

After #174477, I found similar
logic that can be replaced by `memref::getMixedSize` in the
FatRawBufferCastOp dimension reification function.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 6, 2026
…eification (#174548)

After llvm/llvm-project#174477, I found similar
logic that can be replaced by `memref::getMixedSize` in the
FatRawBufferCastOp dimension reification function.
@Abhishek-Varma Abhishek-Varma merged commit c85b8ff into llvm:main Jan 7, 2026
10 checks passed
@jtuyls jtuyls deleted the reify-assume-dim branch January 7, 2026 08:30
navaneethshan pushed a commit to qualcomm/cpullvm-toolchain that referenced this pull request Jan 8, 2026
…(#174548)

After llvm/llvm-project#174477, I found similar
logic that can be replaced by `memref::getMixedSize` in the
FatRawBufferCastOp dimension reification function.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants