|  | 
|  | 1 | +//====- UnifyFuncReturn.cpp -------------------------------------*- C++ -*-===// | 
|  | 2 | +// | 
|  | 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | 4 | +// See https://llvm.org/LICENSE.txt for license information. | 
|  | 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | 6 | +// | 
|  | 7 | +//===----------------------------------------------------------------------===// | 
|  | 8 | + | 
|  | 9 | +#include "PassDetail.h" | 
|  | 10 | +#include "mlir/IR/PatternMatch.h" | 
|  | 11 | +#include "mlir/Support/LogicalResult.h" | 
|  | 12 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | 
|  | 13 | +#include "clang/CIR/Dialect/IR/CIRDialect.h" | 
|  | 14 | +#include "clang/CIR/Dialect/Passes.h" | 
|  | 15 | +#include "llvm/Support/TimeProfiler.h" | 
|  | 16 | + | 
|  | 17 | +using namespace mlir; | 
|  | 18 | +using namespace cir; | 
|  | 19 | + | 
|  | 20 | +namespace { | 
|  | 21 | + | 
|  | 22 | +struct UnifyFuncReturnPass : public UnifyFuncReturnBase<UnifyFuncReturnPass> { | 
|  | 23 | +  UnifyFuncReturnPass() = default; | 
|  | 24 | +  void runOnOperation() override; | 
|  | 25 | + | 
|  | 26 | +private: | 
|  | 27 | +  void unifyReturn(FuncOp func); | 
|  | 28 | +}; | 
|  | 29 | + | 
|  | 30 | +struct UnifyReturn : public OpRewritePattern<ReturnOp> { | 
|  | 31 | +  using OpRewritePattern<ReturnOp>::OpRewritePattern; | 
|  | 32 | + | 
|  | 33 | +  UnifyReturn(MLIRContext *context, cir::FuncOp func, Block *retBlock) | 
|  | 34 | +      : OpRewritePattern<ReturnOp>(context), func(func), retBlock(retBlock) {} | 
|  | 35 | + | 
|  | 36 | +  mlir::LogicalResult | 
|  | 37 | +  matchAndRewrite(cir::ReturnOp ret, | 
|  | 38 | +                  mlir::PatternRewriter &rewriter) const override { | 
|  | 39 | +    mlir::OpBuilder::InsertionGuard guard(rewriter); | 
|  | 40 | +    auto fn = ret->getParentOfType<cir::FuncOp>(); | 
|  | 41 | +    if (!fn || fn != func) | 
|  | 42 | +      return mlir::failure(); | 
|  | 43 | +    // Replace 'return' with 'br <retBlock>' | 
|  | 44 | +    rewriter.replaceOpWithNewOp<cir::BrOp>(ret, ret.getInput(), retBlock); | 
|  | 45 | +    return mlir::success(); | 
|  | 46 | +  } | 
|  | 47 | + | 
|  | 48 | +private: | 
|  | 49 | +  cir::FuncOp func; | 
|  | 50 | +  Block *retBlock; | 
|  | 51 | +}; | 
|  | 52 | + | 
|  | 53 | +} // namespace | 
|  | 54 | + | 
|  | 55 | +void UnifyFuncReturnPass::unifyReturn(cir::FuncOp func) { | 
|  | 56 | +  if (func.getRegion().empty()) | 
|  | 57 | +    return; | 
|  | 58 | + | 
|  | 59 | +  bool hasRetVals = func.getNumResults() > 0; | 
|  | 60 | +  auto *endBody = &func.getBody().back(); | 
|  | 61 | +  auto *retBlock = endBody->splitBlock(endBody->end()); | 
|  | 62 | +  if (hasRetVals) | 
|  | 63 | +    retBlock->addArguments(func.getResultTypes(), func.getLoc()); | 
|  | 64 | + | 
|  | 65 | +  RewritePatternSet patterns(&getContext()); | 
|  | 66 | +  patterns.add<UnifyReturn>(patterns.getContext(), func, retBlock); | 
|  | 67 | + | 
|  | 68 | +  // Collect operations to apply patterns. | 
|  | 69 | +  llvm::SmallVector<Operation *, 16> ops; | 
|  | 70 | +  func->walk([&](cir::ReturnOp op) { ops.push_back(op.getOperation()); }); | 
|  | 71 | + | 
|  | 72 | +  // Apply patterns. | 
|  | 73 | +  if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) | 
|  | 74 | +    signalPassFailure(); | 
|  | 75 | + | 
|  | 76 | +  auto builder = OpBuilder::atBlockBegin(retBlock); | 
|  | 77 | +  if (hasRetVals) | 
|  | 78 | +    builder.create<cir::ReturnOp>(func.getLoc(), retBlock->getArguments()); | 
|  | 79 | +  else | 
|  | 80 | +    builder.create<cir::ReturnOp>(func.getLoc()); | 
|  | 81 | +} | 
|  | 82 | + | 
|  | 83 | +void UnifyFuncReturnPass::runOnOperation() { | 
|  | 84 | +  llvm::TimeTraceScope scope("Unify function returns"); | 
|  | 85 | + | 
|  | 86 | +  // Collect operations to apply patterns. | 
|  | 87 | +  llvm::SmallVector<Operation *, 16> ops; | 
|  | 88 | +  getOperation()->walk([&](cir::FuncOp op) { unifyReturn(op); }); | 
|  | 89 | +} | 
|  | 90 | + | 
|  | 91 | +namespace mlir { | 
|  | 92 | +std::unique_ptr<Pass> createUnifyFuncReturnPass() { | 
|  | 93 | +  return std::make_unique<UnifyFuncReturnPass>(); | 
|  | 94 | +} | 
|  | 95 | +} // namespace mlir | 
0 commit comments