Reapply "[AMDGPU] fold memref.subview/expand_shape/collapse_shape into amdgpu.gather_to_lds" #150334


Merged
3 commits merged from lialan/fold_memrefs into llvm:main on Jul 24, 2025

Conversation

@lialan (Member) commented Jul 23, 2025

This is a reapply of patch #149851. The reapply also fixes a CMake/Bazel build issue, which was the reason for the revert. (Thanks @rupprecht!)

Original patch (#149851) message:

This PR adds a new optimization pass to fold memref.subview/expand_shape/collapse_shape ops into consumer amdgpu.gather_to_lds operations.

  • Implements a new pass AmdgpuFoldMemRefOpsPass with pattern FoldMemRefOpsIntoGatherToLDSOp
  • Adds corresponding folding tests
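
For illustration, here is a minimal before/after sketch of the subview case, adapted from the test file this patch adds (mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir); %offset_i and %offset_j stand for index function arguments, and the shortened names are for exposition only:

  // Before: the gather reads through a zero-offset subview.
  %alloc = memref.alloc() : memref<64x64xf16, 3>
  %mem = memref.alloc() : memref<64x128xf16>
  %subview = memref.subview %mem[0, 0][32, 64][1, 1]
      : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1]>>
  %c0 = arith.constant 0 : index
  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
      : vector<8xf16>, memref<32x64xf16, strided<[128, 1]>>, memref<64x64xf16, 3>

  // After: the subview is folded away. With a zero offset the indices pass
  // through unchanged; with a nonzero offset they are rebased via affine.apply
  // (see the subview_folding_offset test in the diff below).
  amdgpu.gather_to_lds %mem[%offset_i, %offset_j], %alloc[%c0, %c0]
      : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>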

@llvmbot added the labels backend:AMDGPU, mlir:gpu, mlir, mlir:memref, bazel ("Peripheral" support tier, build system: utils/bazel), and mlir:amdgpu on Jul 23, 2025
@llvmbot (Member) commented Jul 23, 2025

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-amdgpu

Author: Alan Li (lialan)

Changes

This is a reapply of patch #149851. The reapply also fixes a Bazel build issue, which was the reason for the revert. (Thanks @rupprecht!)

Original patch (#149851) message:

This PR adds a new optimization pass to fold memref.subview/expand_shape/collapse_shape ops into consumer amdgpu.gather_to_lds operations.

  • Implements a new pass AmdgpuFoldMemRefOpsPass with pattern FoldMemRefOpsIntoGatherToLDSOp
  • Adds corresponding folding tests

Patch is 21.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150334.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+5-1)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+12)
  • (modified) mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h (+37)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+2-1)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp (+97)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (-91)
  • (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+66)
  • (added) mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir (+94)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index cc2f543e79f69..58b9c74b2f8e0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,8 +22,9 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUFOLDMEMREFOPSPASS
 #define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -38,6 +39,9 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
 void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
 } // namespace amdgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 8d0e6829ab0cc..8664f971cabde 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -70,4 +70,16 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
     "memref::MemRefDialect"
   ];
 }
+
+def AmdgpuFoldMemRefOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
+  let summary = "Fold memref operations into their consumer operations";
+  let description = [{
+    This pass identifies memref operations (subview, expand_shape, collapse_shape)
+    that are sources of `GatherToLDSOp` and attempts to fold the source ops,
+    potentially simplifying the overall operation and improving performance.
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect"
+  ];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 34ad279a07a8b..dd3b3dea6ef26 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -116,6 +116,43 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
 /// the source memref (i.e. implements ViewLikeOpInterface).
 MemrefValue skipViewLikeOps(MemrefValue source);
 
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of an expand_shape op, returns the indices w.r.t. the source memref of the
+/// expand_shape op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+///    : memref<12x42xf32> into memref<2x6x42xf32>
+/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[6 * %i1 + %i2, %i3] :
+///          memref<12x42xf32>
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t. the source memref of
+/// the collapse_shape op. For example
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+///    : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+///          memref<2x6x42xf32>
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices);
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 17bbe54ea6c0c..3b0c072ed1217 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
-  ResolveStridedMetadata.cpp
+  FoldMemRefsOps.cpp
   MaskedloadToLoad.cpp
+  ResolveStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
new file mode 100644
index 0000000000000..a3fdc7ee385ed
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -0,0 +1,97 @@
+//===- FoldMemRefsOps.cpp - AMDGPU fold memref ops ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+
+struct AmdgpuFoldMemRefOpsPass final
+    : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuFoldMemRefOpsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherToLDSOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value memrefSource;
+    SmallVector<Value> sourceIndices;
+    auto foldResult =
+        llvm::TypeSwitch<Operation *, LogicalResult>(
+            op.getSrc().getDefiningOp())
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+              // If the source is a SubViewOp, we can directly rewrite the
+              // GatherToLDSOp.
+              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                  rewriter, loc, subviewOp.getMixedOffsets(),
+                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                  op.getSrcIndices(), sourceIndices);
+              memrefSource = subviewOp.getSource();
+              return success();
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expandShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
+                          sourceIndices, false))) {
+                    return failure();
+                  }
+                  memrefSource = expandShapeOp.getViewSource();
+                  return success();
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapseShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                          sourceIndices))) {
+                    return failure();
+                  }
+                  memrefSource = collapseShapeOp.getViewSource();
+                  return success();
+                })
+            .Default([&](Operation *op) {
+              // If the source is not a SubViewOp, ExpandShapeOp, or
+              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+              return rewriter.notifyMatchFailure(
+                  op,
+                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
+                  "CollapseShapeOp");
+            });
+
+    if (failed(foldResult)) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
+                                               op.getDst(), op.getDstIndices(),
+                                               op.getTransferType());
+
+    return success();
+  }
+};
+
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit) {
+  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
+}
+} // namespace mlir::amdgpu
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 89be188af9129..24da447ad7685 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -44,97 +44,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = memref.expand_shape %0 [[0, 1], [2]]
-///    : memref<12x42xf32> into memref<2x6x42xf32>
-/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
-///
-/// could be folded into
-///
-/// %2 = load %0[6 * i1 + i2, %i3] :
-///          memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
-
-  // Traverse all reassociation groups to determine the appropriate indices
-  // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
-    }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
-  }
-  return success();
-}
-
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
-///
-/// %0 = ... : memref<2x6x42xf32>
-/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
-///    : memref<2x6x42xf32> into memref<12x42xf32>
-/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
-///
-/// could be folded into
-///
-/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
-///          memref<2x6x42xf32>
-static LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
-                                  memref::CollapseShapeOp collapseShapeOp,
-                                  ValueRange indices,
-                                  SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
-    }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
-  }
-  if (collapseShapeOp.getReassociationIndices().empty()) {
-    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
-    int64_t srcRank =
-        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
-    for (int64_t i = 0; i < srcRank; i++) {
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-    }
-  }
-  return success();
-}
-
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index a50b4cfc74708..97fe3cb5b4705 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
   return source;
 }
 
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
+    }
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
+  }
+  return success();
+}
+
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices) {
+  // Note: collapse_shape requires a strided memref, so we can extract its
+  // strided metadata here.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
+    }
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
+  }
+  if (collapseShapeOp.getReassociationIndices().empty()) {
+    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+    int64_t srcRank =
+        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
+    for (int64_t i = 0; i < srcRank; i++) {
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+  }
+  return success();
+}
+
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
new file mode 100644
index 0000000000000..57afa127c9da8
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt --amdgpu-fold-memrefs-ops --split-input-file %s | FileCheck %s
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_subview_folding
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_subview_folding(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]], %[[ARG1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[0, 0][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1]>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1]>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 32)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 64)>
+
+// CHECK: func @subview_folding_offset
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+  // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[32, 64][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1], offset: 4160>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 4160>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+
+  %alloc = memre...
[truncated]
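
For the shape-changing cases the pass rewrites the indices rather than just the source: indices into an expand_shape result are linearized onto the flat source, and an index into a collapse_shape result is delinearized into the source's basis. Below is a sketch of the rewritten gathers, reconstructed from the test_expand_shape CHECK lines above and the resolveSourceIndices* doc comments; the collapse_shape shapes are assumptions, since that part of the test file is truncated away:

  // expand_shape (memref<8192xf16> into memref<64x128xf16>):
  // fold by linearizing the indices onto the flat source.
  %idx = affine.linearize_index [%offset_i, %offset_j] by (64, 128) : index
  amdgpu.gather_to_lds %mem[%idx], %alloc[%c0, %c0]
      : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>

  // collapse_shape (memref<64x128xf16> into memref<8192xf16>, hypothetical):
  // fold by delinearizing the index into the source's basis.
  %i0, %i1 = affine.delinearize_index %offset into (64, 128) : index, index
  amdgpu.gather_to_lds %src[%i0, %i1], %alloc[%c0, %c0]
      : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>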


@rupprecht (Collaborator) commented

The reapply also fixes a Bazel build issue, which was the reason of the revert

The buildbot failures I see on the original commit are all from non-Bazel builds, so the more important thing to do in this PR is to update the relevant CMakeLists.txt file w/ whatever dep needs to be included. Bazel failures are generally non-blocking.

The link error looks similar to the bazel error, so I think you need to add MLIRAffineUtils to the deps list in mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt.

The bazel change here LGTM. Thanks!
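
For reference, the suggested fix would look roughly like this in mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (a sketch only; the elided entries stand for whatever the file already lists):

  add_mlir_dialect_library(MLIRAMDGPUTransforms
    ...

    LINK_LIBS PUBLIC
      MLIRAffineUtils
      ...
    )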

@lialan requested review from krzysz00 and removed the requests for keith, rupprecht, and aaronmondal on July 23, 2025 at 23:53
@krzysz00 (Contributor) left a comment

LGTM

@lialan merged commit 1c3e4e9 into llvm:main on Jul 24, 2025
9 checks passed
@llvm-ci (Collaborator) commented Jul 24, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir,utils at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/16026

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/GPU/CUDA/async.mlir' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary="format=fatbin"  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so    --entry-point-result=void -O0  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt '-pass-pipeline=builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary=format=fatbin
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so --entry-point-result=void -O0
# .---command stderr------------
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventSynchronize(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# `-----------------------------
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# .---command stderr------------
# | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir:68:12: error: CHECK: expected string not found in input
# |  // CHECK: [84, 84]
# |            ^
# | <stdin>:1:1: note: scanning from here
# | Unranked Memref base@ = 0x5adc35727a00 rank = 1 offset = 0 sizes = [2] strides = [1] data = 
# | ^
# | <stdin>:2:1: note: possible intended match here
# | [42, 42]
# | ^
# | 
# | Input file: <stdin>
# | Check file: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             1: Unranked Memref base@ = 0x5adc35727a00 rank = 1 offset = 0 sizes = [2] strides = [1] data =  
# | check:68'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
# |             2: [42, 42] 
# | check:68'0     ~~~~~~~~~
# | check:68'1     ?         possible intended match
...

@lialan deleted the lialan/fold_memrefs branch on July 24, 2025 at 13:35
@llvm-ci (Collaborator) commented Jul 24, 2025

LLVM Buildbot has detected a new failure on builder premerge-monolithic-linux running on premerge-linux-1 while building mlir,utils at step 7 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/39210

Here is the relevant piece of the build log for the reference
Step 7 (test-build-unified-tree-check-all) failure: test (failure)
...
PASS: lld :: COFF/duplicate-dwarf.s (98851 of 101866)
PASS: lld :: COFF/duplicate-cv.s (98852 of 101866)
PASS: lld :: COFF/delayimports-error.test (98853 of 101866)
PASS: lld :: COFF/duplicate-absolute.s (98854 of 101866)
PASS: lld :: COFF/baserel.test (98855 of 101866)
PASS: lld :: COFF/defparser.test (98856 of 101866)
PASS: lld :: COFF/arm64x-import.test (98857 of 101866)
PASS: lld :: COFF/def-name.test (98858 of 101866)
PASS: lld :: COFF/duplicate.test (98859 of 101866)
TIMEOUT: MLIR :: Examples/standalone/test.toy (98860 of 101866)
******************** TEST 'MLIR :: Examples/standalone/test.toy' FAILED ********************
Exit Code: 1
Timeout: Reached timeout of 60 seconds

Command Output (stdout):
--
# RUN: at line 1
"/etc/cmake/bin/cmake" "/build/buildbot/premerge-monolithic-linux/llvm-project/mlir/examples/standalone" -G "Ninja"  -DCMAKE_CXX_COMPILER=/usr/bin/clang++ -DCMAKE_C_COMPILER=/usr/bin/clang  -DLLVM_ENABLE_LIBCXX=OFF -DMLIR_DIR=/build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir  -DLLVM_USE_LINKER=lld  -DPython3_EXECUTABLE="/usr/bin/python3.10"
# executed command: /etc/cmake/bin/cmake /build/buildbot/premerge-monolithic-linux/llvm-project/mlir/examples/standalone -G Ninja -DCMAKE_CXX_COMPILER=/usr/bin/clang++ -DCMAKE_C_COMPILER=/usr/bin/clang -DLLVM_ENABLE_LIBCXX=OFF -DMLIR_DIR=/build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir -DLLVM_USE_LINKER=lld -DPython3_EXECUTABLE=/usr/bin/python3.10
# .---command stdout------------
# | -- The CXX compiler identification is Clang 16.0.6
# | -- The C compiler identification is Clang 16.0.6
# | -- Detecting CXX compiler ABI info
# | -- Detecting CXX compiler ABI info - done
# | -- Check for working CXX compiler: /usr/bin/clang++ - skipped
# | -- Detecting CXX compile features
# | -- Detecting CXX compile features - done
# | -- Detecting C compiler ABI info
# | -- Detecting C compiler ABI info - done
# | -- Check for working C compiler: /usr/bin/clang - skipped
# | -- Detecting C compile features
# | -- Detecting C compile features - done
# | -- Looking for histedit.h
# | -- Looking for histedit.h - found
# | -- Found LibEdit: /usr/include (found version "2.11") 
# | -- Found ZLIB: /usr/lib/x86_64-linux-gnu/libz.so (found version "1.2.11") 
# | -- Found LibXml2: /usr/lib/x86_64-linux-gnu/libxml2.so (found version "2.9.13") 
# | -- Using MLIRConfig.cmake in: /build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir
# | -- Using LLVMConfig.cmake in: /build/buildbot/premerge-monolithic-linux/build/lib/cmake/llvm
# | -- Linker detection: unknown
# | -- Performing Test LLVM_LIBSTDCXX_MIN
# | -- Performing Test LLVM_LIBSTDCXX_MIN - Success
# | -- Performing Test LLVM_LIBSTDCXX_SOFT_ERROR
# | -- Performing Test LLVM_LIBSTDCXX_SOFT_ERROR - Success
# | -- Performing Test CXX_SUPPORTS_CUSTOM_LINKER
# | -- Performing Test CXX_SUPPORTS_CUSTOM_LINKER - Success
# | -- Performing Test C_SUPPORTS_FPIC
# | -- Performing Test C_SUPPORTS_FPIC - Success
# | -- Performing Test CXX_SUPPORTS_FPIC

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
…nto `amdgpu.gather_to_lds`" (llvm#150334)

This is a reapply of patch llvm#149851. The reapply also fixes a CMake/Bazel
build issue, which was the reason of the revert. (Thanks @rupprecht )

Original patch (llvm#149851) message:
-----
This PR adds a new optimization pass to fold
`memref.subview/expand_shape/collapse_shape` ops into consumer
`amdgpu.gather_to_lds` operations.
* Implements a new pass `AmdgpuFoldMemRefOpsPass` with pattern
`FoldMemRefOpsIntoGatherToLDSOp`
* Adds corresponding folding tests
5 participants