4848#include " mlir/Transforms/DialectConversion.h"
4949#include " clang/CIR/Dialect/IR/CIRDialect.h"
5050#include " clang/CIR/Dialect/IR/CIRTypes.h"
51+ #include " clang/CIR/Interfaces/CIRLoopOpInterface.h"
5152#include " clang/CIR/LowerToLLVM.h"
5253#include " clang/CIR/LowerToMLIR.h"
5354#include " clang/CIR/LoweringHelpers.h"
5455#include " clang/CIR/Passes.h"
5556#include " llvm/ADT/STLExtras.h"
56- #include " llvm/Support/ErrorHandling.h"
57- #include " clang/CIR/Interfaces/CIRLoopOpInterface.h"
58- #include " clang/CIR/LowerToLLVM.h"
59- #include " clang/CIR/Passes.h"
6057#include " llvm/ADT/Sequence.h"
6158#include " llvm/ADT/SmallVector.h"
6259#include " llvm/ADT/TypeSwitch.h"
6360#include " llvm/IR/Value.h"
61+ #include " llvm/Support/ErrorHandling.h"
6462#include " llvm/Support/TimeProfiler.h"
6563
6664using namespace cir ;
@@ -946,8 +944,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
946944 } else {
947945 // For scopes with results, use scf.execute_region
948946 SmallVector<mlir::Type> types;
949- if (mlir::failed (
950- getTypeConverter ()-> convertTypes ( scopeOp->getResultTypes (), types)))
947+ if (mlir::failed (getTypeConverter ()-> convertTypes (
948+ scopeOp->getResultTypes (), types)))
951949 return mlir::failure ();
952950 auto exec =
953951 rewriter.create <mlir::scf::ExecuteRegionOp>(scopeOp.getLoc (), types);
@@ -1515,28 +1513,117 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
15151513 }
15161514};
15171515
1516+ class CIRSwitchOpLowering : public mlir ::OpConversionPattern<cir::SwitchOp> {
1517+ public:
1518+ using OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1519+
1520+ mlir::LogicalResult
1521+ matchAndRewrite (cir::SwitchOp op, OpAdaptor adaptor,
1522+ mlir::ConversionPatternRewriter &rewriter) const override {
1523+ rewriter.setInsertionPointAfter (op);
1524+ llvm::SmallVector<CaseOp> cases;
1525+ if (!op.isSimpleForm (cases))
1526+ llvm_unreachable (" NYI" );
1527+
1528+ llvm::SmallVector<int64_t > caseValues;
1529+ // Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1530+ // This is necessary because some CaseOp might carry 0 or multiple values.
1531+ llvm::DenseMap<size_t , unsigned > indexMap;
1532+ caseValues.reserve (cases.size ());
1533+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1534+ switch (caseOp.getKind ()) {
1535+ case CaseOpKind::Equal: {
1536+ auto valueAttr = caseOp.getValue ()[0 ];
1537+ auto value = cast<cir::IntAttr>(valueAttr);
1538+ indexMap[i] = caseValues.size ();
1539+ caseValues.push_back (value.getUInt ());
1540+ break ;
1541+ }
1542+ case CaseOpKind::Default:
1543+ break ;
1544+ case CaseOpKind::Range:
1545+ case CaseOpKind::Anyof:
1546+ llvm_unreachable (" NYI" );
1547+ }
1548+ }
1549+
1550+ auto operand = adaptor.getOperands ()[0 ];
1551+ // `scf.index_switch` expects an index of type `index`.
1552+ auto indexType = mlir::IndexType::get (getContext ());
1553+ auto indexCast = rewriter.create <mlir::arith::IndexCastOp>(
1554+ op.getLoc (), indexType, operand);
1555+ auto indexSwitch = rewriter.create <mlir::scf::IndexSwitchOp>(
1556+ op.getLoc (), mlir::TypeRange{}, indexCast, caseValues, cases.size ());
1557+
1558+ bool metDefault = false ;
1559+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1560+ auto ®ion = caseOp.getRegion ();
1561+ switch (caseOp.getKind ()) {
1562+ case CaseOpKind::Equal: {
1563+ auto &caseRegion = indexSwitch.getCaseRegions ()[indexMap[i]];
1564+ rewriter.inlineRegionBefore (region, caseRegion, caseRegion.end ());
1565+ break ;
1566+ }
1567+ case CaseOpKind::Default: {
1568+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1569+ rewriter.inlineRegionBefore (region, defaultRegion, defaultRegion.end ());
1570+ metDefault = true ;
1571+ break ;
1572+ }
1573+ case CaseOpKind::Range:
1574+ case CaseOpKind::Anyof:
1575+ llvm_unreachable (" NYI" );
1576+ }
1577+ }
1578+
1579+ // `scf.index_switch` expects its default region to contain exactly one
1580+ // block. If we don't have a default region in `cir.switch`, we need to
1581+ // supply it here.
1582+ if (!metDefault) {
1583+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1584+ mlir::Block *block =
1585+ rewriter.createBlock (&defaultRegion, defaultRegion.end ());
1586+ rewriter.setInsertionPointToEnd (block);
1587+ rewriter.create <mlir::scf::YieldOp>(op.getLoc ());
1588+ }
1589+
1590+ // The final `cir.break` should be replaced to `scf.yield`.
1591+ // After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1592+ for (auto ®ion : indexSwitch.getCaseRegions ()) {
1593+ auto &lastBlock = region.back ();
1594+ auto &lastOp = lastBlock.back ();
1595+ assert (isa<BreakOp>(lastOp));
1596+ rewriter.setInsertionPointAfter (&lastOp);
1597+ rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(&lastOp);
1598+ }
1599+
1600+ rewriter.replaceOp (op, indexSwitch);
1601+
1602+ return mlir::success ();
1603+ }
1604+ };
1605+
15181606void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
15191607 mlir::TypeConverter &converter) {
15201608 patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
15211609
1522- patterns
1523- .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1524- CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1525- CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1526- CIRFuncOpLowering, CIRBrCondOpLowering,
1527- CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1528- CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1529- CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1530- CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1531- CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1532- CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1533- CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1534- CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1535- CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1536- CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1537- CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1538- CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1539- CIRTrapOpLowering>(converter, patterns.getContext ());
1610+ patterns.add <
1611+ CIRSwitchOpLowering, CIRATanOpLowering, CIRCmpOpLowering,
1612+ CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1613+ CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1614+ CIRAllocaOpLowering, CIRFuncOpLowering, CIRBrCondOpLowering,
1615+ CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1616+ CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1617+ CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1618+ CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1619+ CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1620+ CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
1621+ CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1622+ CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1623+ CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1624+ CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
1625+ CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
1626+ CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext ());
15401627}
15411628
15421629static mlir::TypeConverter prepareTypeConverter () {
@@ -1610,7 +1697,7 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16101697 mlir::ModuleOp theModule = getOperation ();
16111698
16121699 auto converter = prepareTypeConverter ();
1613-
1700+
16141701 mlir::RewritePatternSet patterns (&getContext ());
16151702
16161703 populateCIRLoopToSCFConversionPatterns (patterns, converter);
@@ -1628,10 +1715,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16281715 // cir dialect, for example the `cir.continue`. If we marked cir as illegal
16291716 // here, then MLIR would think any remaining `cir.continue` indicates a
16301717 // failure, which is not what we want.
1631-
1632- patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);
16331718
1634- if (mlir::failed (mlir::applyPartialConversion (theModule, target,
1719+ patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
1720+ CIRYieldOpLowering>(converter, context);
1721+
1722+ if (mlir::failed (mlir::applyPartialConversion (theModule, target,
16351723 std::move (patterns)))) {
16361724 signalPassFailure ();
16371725 }
@@ -1646,6 +1734,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
16461734
16471735 mlir::PassManager pm (mlirCtx);
16481736
1737+ pm.addPass (createMLIRLoweringPreparePass ());
16491738 pm.addPass (createConvertCIRToMLIRPass ());
16501739 pm.addPass (createConvertMLIRToLLVMPass ());
16511740
@@ -1712,6 +1801,8 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
17121801 llvm::TimeTraceScope scope (" Lower CIR To MLIR" );
17131802
17141803 mlir::PassManager pm (mlirCtx);
1804+
1805+ pm.addPass (createMLIRLoweringPreparePass ());
17151806 pm.addPass (createConvertCIRToMLIRPass ());
17161807
17171808 auto result = !mlir::failed (pm.run (theModule));
0 commit comments