Skip to content

Commit 8892d4f

Browse files
[CIR][ThroughMLIR] Lower simple SwitchOp.
1 parent a725efb commit 8892d4f

File tree

5 files changed

+287
-28
lines changed

5 files changed

+287
-28
lines changed

clang/include/clang/CIR/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
#include <memory>
1919

2020
namespace cir {
21+
/// Create a pass for transforming CIR operations to more 'scf' dialect-friendly
22+
/// forms. It rewrites operations that aren't supported by 'scf', such as breaks
23+
/// and continues.
24+
std::unique_ptr<mlir::Pass> createMLIRLoweringPreparePass();
25+
2126
/// Create a pass for lowering from MLIR builtin dialects such as `Affine` and
2227
/// `Std`, to the LLVM dialect for codegen.
2328
std::unique_ptr<mlir::Pass> createConvertMLIRToLLVMPass();

clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_clang_library(clangCIRLoweringThroughMLIR
99
LowerCIRLoopToSCF.cpp
1010
LowerCIRToMLIR.cpp
1111
LowerMLIRToLLVM.cpp
12+
MLIRLoweringPrepare.cpp
1213

1314
DEPENDS
1415
MLIRCIROpsIncGen

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 119 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,17 @@
4848
#include "mlir/Transforms/DialectConversion.h"
4949
#include "clang/CIR/Dialect/IR/CIRDialect.h"
5050
#include "clang/CIR/Dialect/IR/CIRTypes.h"
51+
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
5152
#include "clang/CIR/LowerToLLVM.h"
5253
#include "clang/CIR/LowerToMLIR.h"
5354
#include "clang/CIR/LoweringHelpers.h"
5455
#include "clang/CIR/Passes.h"
5556
#include "llvm/ADT/STLExtras.h"
56-
#include "llvm/Support/ErrorHandling.h"
57-
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
58-
#include "clang/CIR/LowerToLLVM.h"
59-
#include "clang/CIR/Passes.h"
6057
#include "llvm/ADT/Sequence.h"
6158
#include "llvm/ADT/SmallVector.h"
6259
#include "llvm/ADT/TypeSwitch.h"
6360
#include "llvm/IR/Value.h"
61+
#include "llvm/Support/ErrorHandling.h"
6462
#include "llvm/Support/TimeProfiler.h"
6563

6664
using namespace cir;
@@ -946,8 +944,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
946944
} else {
947945
// For scopes with results, use scf.execute_region
948946
SmallVector<mlir::Type> types;
949-
if (mlir::failed(
950-
getTypeConverter()->convertTypes(scopeOp->getResultTypes(), types)))
947+
if (mlir::failed(getTypeConverter()->convertTypes(
948+
scopeOp->getResultTypes(), types)))
951949
return mlir::failure();
952950
auto exec =
953951
rewriter.create<mlir::scf::ExecuteRegionOp>(scopeOp.getLoc(), types);
@@ -1515,28 +1513,117 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
15151513
}
15161514
};
15171515

1516+
class CIRSwitchOpLowering : public mlir::OpConversionPattern<cir::SwitchOp> {
1517+
public:
1518+
using OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1519+
1520+
mlir::LogicalResult
1521+
matchAndRewrite(cir::SwitchOp op, OpAdaptor adaptor,
1522+
mlir::ConversionPatternRewriter &rewriter) const override {
1523+
rewriter.setInsertionPointAfter(op);
1524+
llvm::SmallVector<CaseOp> cases;
1525+
if (!op.isSimpleForm(cases))
1526+
llvm_unreachable("NYI");
1527+
1528+
llvm::SmallVector<int64_t> caseValues;
1529+
// Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1530+
// This is necessary because some CaseOp might carry 0 or multiple values.
1531+
llvm::DenseMap<size_t, unsigned> indexMap;
1532+
caseValues.reserve(cases.size());
1533+
for (auto [i, caseOp] : llvm::enumerate(cases)) {
1534+
switch (caseOp.getKind()) {
1535+
case CaseOpKind::Equal: {
1536+
auto valueAttr = caseOp.getValue()[0];
1537+
auto value = cast<cir::IntAttr>(valueAttr);
1538+
indexMap[i] = caseValues.size();
1539+
caseValues.push_back(value.getUInt());
1540+
break;
1541+
}
1542+
case CaseOpKind::Default:
1543+
break;
1544+
case CaseOpKind::Range:
1545+
case CaseOpKind::Anyof:
1546+
llvm_unreachable("NYI");
1547+
}
1548+
}
1549+
1550+
auto operand = adaptor.getOperands()[0];
1551+
// `scf.index_switch` expects an index of type `index`.
1552+
auto indexType = mlir::IndexType::get(getContext());
1553+
auto indexCast = rewriter.create<mlir::arith::IndexCastOp>(
1554+
op.getLoc(), indexType, operand);
1555+
auto indexSwitch = rewriter.create<mlir::scf::IndexSwitchOp>(
1556+
op.getLoc(), mlir::TypeRange{}, indexCast, caseValues, cases.size());
1557+
1558+
bool metDefault = false;
1559+
for (auto [i, caseOp] : llvm::enumerate(cases)) {
1560+
auto &region = caseOp.getRegion();
1561+
switch (caseOp.getKind()) {
1562+
case CaseOpKind::Equal: {
1563+
auto &caseRegion = indexSwitch.getCaseRegions()[indexMap[i]];
1564+
rewriter.inlineRegionBefore(region, caseRegion, caseRegion.end());
1565+
break;
1566+
}
1567+
case CaseOpKind::Default: {
1568+
auto &defaultRegion = indexSwitch.getDefaultRegion();
1569+
rewriter.inlineRegionBefore(region, defaultRegion, defaultRegion.end());
1570+
metDefault = true;
1571+
break;
1572+
}
1573+
case CaseOpKind::Range:
1574+
case CaseOpKind::Anyof:
1575+
llvm_unreachable("NYI");
1576+
}
1577+
}
1578+
1579+
// `scf.index_switch` expects its default region to contain exactly one
1580+
// block. If we don't have a default region in `cir.switch`, we need to
1581+
// supply it here.
1582+
if (!metDefault) {
1583+
auto &defaultRegion = indexSwitch.getDefaultRegion();
1584+
mlir::Block *block =
1585+
rewriter.createBlock(&defaultRegion, defaultRegion.end());
1586+
rewriter.setInsertionPointToEnd(block);
1587+
rewriter.create<mlir::scf::YieldOp>(op.getLoc());
1588+
}
1589+
1590+
// The final `cir.break` should be replaced to `scf.yield`.
1591+
// After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1592+
for (auto &region : indexSwitch.getCaseRegions()) {
1593+
auto &lastBlock = region.back();
1594+
auto &lastOp = lastBlock.back();
1595+
assert(isa<BreakOp>(lastOp));
1596+
rewriter.setInsertionPointAfter(&lastOp);
1597+
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(&lastOp);
1598+
}
1599+
1600+
rewriter.replaceOp(op, indexSwitch);
1601+
1602+
return mlir::success();
1603+
}
1604+
};
1605+
15181606
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
15191607
mlir::TypeConverter &converter) {
15201608
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
15211609

1522-
patterns
1523-
.add<CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1524-
CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1525-
CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1526-
CIRFuncOpLowering, CIRBrCondOpLowering,
1527-
CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1528-
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1529-
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1530-
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1531-
CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1532-
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1533-
CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1534-
CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1535-
CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1536-
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1537-
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1538-
CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1539-
CIRTrapOpLowering>(converter, patterns.getContext());
1610+
patterns.add<
1611+
CIRSwitchOpLowering, CIRATanOpLowering, CIRCmpOpLowering,
1612+
CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1613+
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1614+
CIRAllocaOpLowering, CIRFuncOpLowering, CIRBrCondOpLowering,
1615+
CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1616+
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1617+
CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1618+
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1619+
CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1620+
CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
1621+
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1622+
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1623+
CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1624+
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
1625+
CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
1626+
CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext());
15401627
}
15411628

15421629
static mlir::TypeConverter prepareTypeConverter() {
@@ -1610,7 +1697,7 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16101697
mlir::ModuleOp theModule = getOperation();
16111698

16121699
auto converter = prepareTypeConverter();
1613-
1700+
16141701
mlir::RewritePatternSet patterns(&getContext());
16151702

16161703
populateCIRLoopToSCFConversionPatterns(patterns, converter);
@@ -1628,10 +1715,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16281715
// cir dialect, for example the `cir.continue`. If we marked cir as illegal
16291716
// here, then MLIR would think any remaining `cir.continue` indicates a
16301717
// failure, which is not what we want.
1631-
1632-
patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);
16331718

1634-
if (mlir::failed(mlir::applyPartialConversion(theModule, target,
1719+
patterns.add<CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
1720+
CIRYieldOpLowering>(converter, context);
1721+
1722+
if (mlir::failed(mlir::applyPartialConversion(theModule, target,
16351723
std::move(patterns)))) {
16361724
signalPassFailure();
16371725
}
@@ -1646,6 +1734,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
16461734

16471735
mlir::PassManager pm(mlirCtx);
16481736

1737+
pm.addPass(createMLIRLoweringPreparePass());
16491738
pm.addPass(createConvertCIRToMLIRPass());
16501739
pm.addPass(createConvertMLIRToLLVMPass());
16511740

@@ -1712,6 +1801,8 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
17121801
llvm::TimeTraceScope scope("Lower CIR To MLIR");
17131802

17141803
mlir::PassManager pm(mlirCtx);
1804+
1805+
pm.addPass(createMLIRLoweringPreparePass());
17151806
pm.addPass(createConvertCIRToMLIRPass());
17161807

17171808
auto result = !mlir::failed(pm.run(theModule));
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#include "mlir/IR/BuiltinOps.h"
2+
#include "mlir/IR/IRMapping.h"
3+
#include "mlir/Pass/Pass.h"
4+
#include "mlir/Transforms/DialectConversion.h"
5+
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
6+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
7+
8+
using namespace llvm;
9+
using namespace cir;
10+
11+
namespace cir {
12+
13+
struct MLIRLoweringPrepare
14+
: public mlir::PassWrapper<MLIRLoweringPrepare,
15+
mlir::OperationPass<mlir::ModuleOp>> {
16+
// `scf.index_switch` requires that switch branches do not fall through.
17+
// We need to copy the next branch's body when the current `cir.case` does not
18+
// terminate with a break.
19+
void removeFallthrough(llvm::SmallVector<CaseOp> &cases);
20+
21+
void runOnOp(mlir::Operation *op);
22+
void runOnOperation() final;
23+
24+
StringRef getDescription() const override {
25+
return "Rewrite CIR module to be more 'scf' dialect-friendly";
26+
}
27+
28+
StringRef getArgument() const override { return "mlir-lowering-prepare"; }
29+
};
30+
31+
// `scf.index_switch` requires that switch branches do not fall through.
32+
// We need to copy the next branch's body when the current `cir.case` does not
33+
// terminate with a break.
34+
void MLIRLoweringPrepare::removeFallthrough(llvm::SmallVector<CaseOp> &cases) {
35+
CIRBaseBuilderTy builder(getContext());
36+
// Note we enumerate in the reverse order, to facilitate the cloning.
37+
for (auto it = cases.rbegin(); it != cases.rend(); it++) {
38+
auto caseOp = *it;
39+
auto &region = caseOp.getRegion();
40+
auto &lastBlock = region.back();
41+
mlir::Operation &last = lastBlock.back();
42+
if (isa<BreakOp>(last))
43+
continue;
44+
45+
// The last op must be a `cir.yield`. As it falls through, we copy the
46+
// previous case's body to this one.
47+
if (!isa<YieldOp>(last)) {
48+
caseOp->dump();
49+
continue;
50+
}
51+
assert(isa<YieldOp>(last));
52+
53+
// If there's no previous case, we can simply change the yield into a break.
54+
if (it == cases.rbegin()) {
55+
builder.setInsertionPointAfter(&last);
56+
builder.create<BreakOp>(last.getLoc());
57+
last.erase();
58+
continue;
59+
}
60+
61+
auto prevIt = it;
62+
--prevIt;
63+
CaseOp &prev = *prevIt;
64+
auto &prevRegion = prev.getRegion();
65+
mlir::IRMapping mapping;
66+
builder.cloneRegionBefore(prevRegion, region, region.end());
67+
68+
// We inline the block to the end.
69+
// This is required because `scf.index_switch` expects that each of its
70+
// region contains a single block.
71+
mlir::Block *cloned = lastBlock.getNextNode();
72+
for (auto it = cloned->begin(); it != cloned->end();) {
73+
auto next = it;
74+
next++;
75+
it->moveBefore(&last);
76+
it = next;
77+
}
78+
cloned->erase();
79+
last.erase();
80+
}
81+
}
82+
83+
void MLIRLoweringPrepare::runOnOp(mlir::Operation *op) {
84+
if (auto switchOp = dyn_cast<SwitchOp>(op)) {
85+
llvm::SmallVector<CaseOp> cases;
86+
if (!switchOp.isSimpleForm(cases))
87+
llvm_unreachable("NYI");
88+
89+
removeFallthrough(cases);
90+
return;
91+
}
92+
llvm_unreachable("unexpected op type");
93+
}
94+
95+
void MLIRLoweringPrepare::runOnOperation() {
96+
auto module = getOperation();
97+
98+
llvm::SmallVector<mlir::Operation *> opsToTransform;
99+
module->walk([&](mlir::Operation *op) {
100+
if (isa<SwitchOp>(op))
101+
opsToTransform.push_back(op);
102+
});
103+
104+
for (auto *op : opsToTransform)
105+
runOnOp(op);
106+
}
107+
108+
std::unique_ptr<mlir::Pass> createMLIRLoweringPreparePass() {
109+
return std::make_unique<MLIRLoweringPrepare>();
110+
}
111+
112+
} // namespace cir
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void fallthrough() {
5+
int i = 0;
6+
switch (i) {
7+
case 2:
8+
i++;
9+
case 3:
10+
i++;
11+
break;
12+
case 8:
13+
i++;
14+
}
15+
16+
// This should copy the `i++; break` in case 3 to case 2.
17+
18+
// CHECK: memref.alloca_scope {
19+
// CHECK: %[[I:.+]] = memref.load %alloca[]
20+
// CHECK: %[[CASTED:.+]] = arith.index_cast %[[I]]
21+
// CHECK: scf.index_switch %[[CASTED]]
22+
// CHECK: case 2 {
23+
// CHECK: %[[I:.+]] = memref.load %alloca[]
24+
// CHECK: %[[ONE:.+]] = arith.constant 1
25+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
26+
// CHECK: memref.store %[[ADD]], %alloca[]
27+
// CHECK: %[[I:.+]] = memref.load %alloca[]
28+
// CHECK: %[[ONE:.+]] = arith.constant 1
29+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
30+
// CHECK: memref.store %[[ADD]], %alloca[]
31+
// CHECK: scf.yield
32+
// CHECK: }
33+
// CHECK: case 3 {
34+
// CHECK: %[[I:.+]] = memref.load %alloca[]
35+
// CHECK: %[[ONE:.+]] = arith.constant 1
36+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
37+
// CHECK: memref.store %[[ADD]], %alloca[]
38+
// CHECK: scf.yield
39+
// CHECK: }
40+
// CHECK: case 8 {
41+
// CHECK: %[[I:.+]] = memref.load %alloca[]
42+
// CHECK: %[[ONE:.+]] = arith.constant 1
43+
// CHECK: %[[ADD:.+]] = arith.addi %[[I]], %[[ONE]]
44+
// CHECK: memref.store %[[ADD]], %alloca[]
45+
// CHECK: scf.yield
46+
// CHECK: }
47+
// CHECK: default {
48+
// CHECK: }
49+
// CHECK: }
50+
}

0 commit comments

Comments
 (0)