diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h index 1a54308264056..8a9e98330f895 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h @@ -77,6 +77,10 @@ class SelectionDAGISel { bool MatchFilterFuncName = false; StringRef FuncName; + // HwMode to be used by getValueTypeForHwMode. This will be initialized + // based on the subtarget used by the MachineFunction. + unsigned HwMode; + explicit SelectionDAGISel(TargetMachine &tm, CodeGenOptLevel OL = CodeGenOptLevel::Default); virtual ~SelectionDAGISel(); @@ -202,7 +206,9 @@ class SelectionDAGISel { // Space-optimized forms that implicitly encode VT. OPC_CheckTypeI32, OPC_CheckTypeI64, + OPC_CheckTypeByHwMode, OPC_CheckTypeRes, + OPC_CheckTypeResByHwMode, OPC_SwitchType, OPC_CheckChild0Type, OPC_CheckChild1Type, @@ -231,6 +237,15 @@ class SelectionDAGISel { OPC_CheckChild6TypeI64, OPC_CheckChild7TypeI64, + OPC_CheckChild0TypeByHwMode, + OPC_CheckChild1TypeByHwMode, + OPC_CheckChild2TypeByHwMode, + OPC_CheckChild3TypeByHwMode, + OPC_CheckChild4TypeByHwMode, + OPC_CheckChild5TypeByHwMode, + OPC_CheckChild6TypeByHwMode, + OPC_CheckChild7TypeByHwMode, + OPC_CheckInteger, OPC_CheckChild0Integer, OPC_CheckChild1Integer, @@ -261,10 +276,13 @@ class SelectionDAGISel { OPC_EmitIntegerI16, OPC_EmitIntegerI32, OPC_EmitIntegerI64, + OPC_EmitIntegerByHwMode, OPC_EmitRegister, OPC_EmitRegisterI32, OPC_EmitRegisterI64, + OPC_EmitRegisterByHwMode, OPC_EmitRegister2, + OPC_EmitRegisterByHwMode2, OPC_EmitConvertToTarget, OPC_EmitConvertToTarget0, OPC_EmitConvertToTarget1, @@ -290,6 +308,7 @@ class SelectionDAGISel { OPC_EmitCopyToRegTwoByte, OPC_EmitNodeXForm, OPC_EmitNode, + OPC_EmitNodeByHwMode, // Space-optimized forms that implicitly encode number of result VTs. OPC_EmitNode0, OPC_EmitNode1, @@ -301,6 +320,7 @@ class SelectionDAGISel { OPC_EmitNode1Chain, OPC_EmitNode2Chain, OPC_MorphNodeTo, + OPC_MorphNodeToByHwMode, // Space-optimized forms that implicitly encode number of result VTs. OPC_MorphNodeTo0, OPC_MorphNodeTo1, @@ -444,6 +464,10 @@ class SelectionDAGISel { llvm_unreachable("Tblgen should generate this!"); } + virtual MVT getValueTypeForHwMode(unsigned Index) const { + llvm_unreachable("Tblgen should generate the implementation of this!"); + } + void SelectCodeCommon(SDNode *NodeToMatch, const uint8_t *MatcherTable, unsigned TableSize); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index e092061fb5e04..2aa775115811a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -531,6 +531,8 @@ void SelectionDAGISel::initializeAnalysisResults( SP = &FAM.getResult(Fn); TTI = &FAM.getResult(Fn); + + HwMode = MF->getSubtarget().getHwMode(); } void SelectionDAGISel::initializeAnalysisResults(MachineFunctionPass &MFP) { @@ -588,6 +590,8 @@ void SelectionDAGISel::initializeAnalysisResults(MachineFunctionPass &MFP) { SP = &MFP.getAnalysis().getLayoutInfo(); TTI = &MFP.getAnalysis().getTTI(Fn); + + HwMode = MF->getSubtarget().getHwMode(); } bool SelectionDAGISel::runOnMachineFunction(MachineFunction &mf) { @@ -2747,6 +2751,14 @@ getSimpleVT(const uint8_t *MatcherTable, unsigned &MatcherIndex) { return static_cast(SimpleVT); } +/// Decode a HwMode VT in MatcherTable by calling getValueTypeForHwMode. +LLVM_ATTRIBUTE_ALWAYS_INLINE static MVT +getHwModeVT(const uint8_t *MatcherTable, unsigned &MatcherIndex, + const SelectionDAGISel &SDISel) { + unsigned Index = MatcherTable[MatcherIndex++]; + return SDISel.getValueTypeForHwMode(Index); +} + void SelectionDAGISel::Select_JUMP_TABLE_DEBUG_INFO(SDNode *N) { SDLoc dl(N); CurDAG->SelectNodeTo(N, TargetOpcode::JUMP_TABLE_DEBUG_INFO, MVT::Glue, @@ -3134,8 +3146,9 @@ static unsigned IsPredicateKnownToFail( return Index; case SelectionDAGISel::OPC_CheckType: case SelectionDAGISel::OPC_CheckTypeI32: - case SelectionDAGISel::OPC_CheckTypeI64: { - MVT::SimpleValueType VT; + case SelectionDAGISel::OPC_CheckTypeI64: + case SelectionDAGISel::OPC_CheckTypeByHwMode: { + MVT VT; switch (Opcode) { case SelectionDAGISel::OPC_CheckTypeI32: VT = MVT::i32; @@ -3143,17 +3156,25 @@ static unsigned IsPredicateKnownToFail( case SelectionDAGISel::OPC_CheckTypeI64: VT = MVT::i64; break; + case SelectionDAGISel::OPC_CheckTypeByHwMode: + VT = getHwModeVT(Table, Index, SDISel); + break; default: VT = getSimpleVT(Table, Index); break; } - Result = !::CheckType(VT, N, SDISel.TLI, SDISel.CurDAG->getDataLayout()); + Result = !::CheckType(VT.SimpleTy, N, SDISel.TLI, + SDISel.CurDAG->getDataLayout()); return Index; } - case SelectionDAGISel::OPC_CheckTypeRes: { + case SelectionDAGISel::OPC_CheckTypeRes: + case SelectionDAGISel::OPC_CheckTypeResByHwMode: { unsigned Res = Table[Index++]; - Result = !::CheckType(getSimpleVT(Table, Index), N.getValue(Res), - SDISel.TLI, SDISel.CurDAG->getDataLayout()); + MVT VT = Opcode == SelectionDAGISel::OPC_CheckTypeResByHwMode + ? getHwModeVT(Table, Index, SDISel) + : getSimpleVT(Table, Index); + Result = !::CheckType(VT.SimpleTy, N.getValue(Res), SDISel.TLI, + SDISel.CurDAG->getDataLayout()); return Index; } case SelectionDAGISel::OPC_CheckChild0Type: @@ -3179,8 +3200,16 @@ static unsigned IsPredicateKnownToFail( case SelectionDAGISel::OPC_CheckChild4TypeI64: case SelectionDAGISel::OPC_CheckChild5TypeI64: case SelectionDAGISel::OPC_CheckChild6TypeI64: - case SelectionDAGISel::OPC_CheckChild7TypeI64: { - MVT::SimpleValueType VT; + case SelectionDAGISel::OPC_CheckChild7TypeI64: + case SelectionDAGISel::OPC_CheckChild0TypeByHwMode: + case SelectionDAGISel::OPC_CheckChild1TypeByHwMode: + case SelectionDAGISel::OPC_CheckChild2TypeByHwMode: + case SelectionDAGISel::OPC_CheckChild3TypeByHwMode: + case SelectionDAGISel::OPC_CheckChild4TypeByHwMode: + case SelectionDAGISel::OPC_CheckChild5TypeByHwMode: + case SelectionDAGISel::OPC_CheckChild6TypeByHwMode: + case SelectionDAGISel::OPC_CheckChild7TypeByHwMode: { + MVT VT; unsigned ChildNo; if (Opcode >= SelectionDAGISel::OPC_CheckChild0TypeI32 && Opcode <= SelectionDAGISel::OPC_CheckChild7TypeI32) { @@ -3190,11 +3219,15 @@ static unsigned IsPredicateKnownToFail( Opcode <= SelectionDAGISel::OPC_CheckChild7TypeI64) { VT = MVT::i64; ChildNo = Opcode - SelectionDAGISel::OPC_CheckChild0TypeI64; + } else if (Opcode >= SelectionDAGISel::OPC_CheckChild0TypeByHwMode && + Opcode <= SelectionDAGISel::OPC_CheckChild7TypeByHwMode) { + VT = getHwModeVT(Table, Index, SDISel); + ChildNo = Opcode - SelectionDAGISel::OPC_CheckChild0TypeByHwMode; } else { VT = getSimpleVT(Table, Index); ChildNo = Opcode - SelectionDAGISel::OPC_CheckChild0Type; } - Result = !::CheckChildType(VT, N, SDISel.TLI, + Result = !::CheckChildType(VT.SimpleTy, N, SDISel.TLI, SDISel.CurDAG->getDataLayout(), ChildNo); return Index; } @@ -3701,8 +3734,9 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case OPC_CheckType: case OPC_CheckTypeI32: - case OPC_CheckTypeI64: { - MVT::SimpleValueType VT; + case OPC_CheckTypeI64: + case OPC_CheckTypeByHwMode: { + MVT VT; switch (Opcode) { case OPC_CheckTypeI32: VT = MVT::i32; @@ -3710,19 +3744,26 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case OPC_CheckTypeI64: VT = MVT::i64; break; + case OPC_CheckTypeByHwMode: + VT = getHwModeVT(MatcherTable, MatcherIndex, *this); + break; default: VT = getSimpleVT(MatcherTable, MatcherIndex); break; } - if (!::CheckType(VT, N, TLI, CurDAG->getDataLayout())) + if (!::CheckType(VT.SimpleTy, N, TLI, CurDAG->getDataLayout())) break; continue; } - case OPC_CheckTypeRes: { + case OPC_CheckTypeRes: + case OPC_CheckTypeResByHwMode: { unsigned Res = MatcherTable[MatcherIndex++]; - if (!::CheckType(getSimpleVT(MatcherTable, MatcherIndex), N.getValue(Res), - TLI, CurDAG->getDataLayout())) + MVT VT = Opcode == OPC_CheckTypeResByHwMode + ? getHwModeVT(MatcherTable, MatcherIndex, *this) + : getSimpleVT(MatcherTable, MatcherIndex); + if (!::CheckType(VT.SimpleTy, N.getValue(Res), TLI, + CurDAG->getDataLayout())) break; continue; } @@ -3832,6 +3873,21 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, break; continue; } + case OPC_CheckChild0TypeByHwMode: + case OPC_CheckChild1TypeByHwMode: + case OPC_CheckChild2TypeByHwMode: + case OPC_CheckChild3TypeByHwMode: + case OPC_CheckChild4TypeByHwMode: + case OPC_CheckChild5TypeByHwMode: + case OPC_CheckChild6TypeByHwMode: + case OPC_CheckChild7TypeByHwMode: { + MVT VT = getHwModeVT(MatcherTable, MatcherIndex, *this); + unsigned ChildNo = Opcode - OPC_CheckChild0TypeByHwMode; + if (!::CheckChildType(VT.SimpleTy, N, TLI, CurDAG->getDataLayout(), + ChildNo)) + break; + continue; + } case OPC_CheckCondCode: if (!::CheckCondCode(MatcherTable, MatcherIndex, N)) break; continue; @@ -3900,8 +3956,9 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case OPC_EmitIntegerI8: case OPC_EmitIntegerI16: case OPC_EmitIntegerI32: - case OPC_EmitIntegerI64: { - MVT::SimpleValueType VT; + case OPC_EmitIntegerI64: + case OPC_EmitIntegerByHwMode: { + MVT VT; switch (Opcode) { case OPC_EmitIntegerI8: VT = MVT::i8; @@ -3915,21 +3972,27 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case OPC_EmitIntegerI64: VT = MVT::i64; break; + case OPC_EmitIntegerByHwMode: + VT = getHwModeVT(MatcherTable, MatcherIndex, *this); + break; default: VT = getSimpleVT(MatcherTable, MatcherIndex); break; } int64_t Val = GetSignedVBR(MatcherTable, MatcherIndex); + Val = SignExtend64(Val, MVT(VT).getFixedSizeInBits()); RecordedNodes.emplace_back( - CurDAG->getSignedConstant(Val, SDLoc(NodeToMatch), VT, + CurDAG->getSignedConstant(Val, SDLoc(NodeToMatch), VT.SimpleTy, /*isTarget=*/true), nullptr); continue; } + case OPC_EmitRegister: case OPC_EmitRegisterI32: - case OPC_EmitRegisterI64: { - MVT::SimpleValueType VT; + case OPC_EmitRegisterI64: + case OPC_EmitRegisterByHwMode: { + MVT VT; switch (Opcode) { case OPC_EmitRegisterI32: VT = MVT::i32; @@ -3937,6 +4000,9 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case OPC_EmitRegisterI64: VT = MVT::i64; break; + case OPC_EmitRegisterByHwMode: + VT = getHwModeVT(MatcherTable, MatcherIndex, *this); + break; default: VT = getSimpleVT(MatcherTable, MatcherIndex); break; @@ -3945,11 +4011,14 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, RecordedNodes.emplace_back(CurDAG->getRegister(RegNo, VT), nullptr); continue; } - case OPC_EmitRegister2: { + case OPC_EmitRegister2: + case OPC_EmitRegisterByHwMode2: { // For targets w/ more than 256 register names, the register enum // values are stored in two bytes in the matcher table (just like // opcodes). - MVT::SimpleValueType VT = getSimpleVT(MatcherTable, MatcherIndex); + MVT VT = Opcode == OPC_EmitRegisterByHwMode2 + ? getHwModeVT(MatcherTable, MatcherIndex, *this) + : getSimpleVT(MatcherTable, MatcherIndex); unsigned RegNo = MatcherTable[MatcherIndex++]; RegNo |= MatcherTable[MatcherIndex++] << 8; RecordedNodes.emplace_back(CurDAG->getRegister(RegNo, VT), nullptr); @@ -4114,6 +4183,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, } case OPC_EmitNode: + case OPC_EmitNodeByHwMode: case OPC_EmitNode0: case OPC_EmitNode1: case OPC_EmitNode2: @@ -4123,6 +4193,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case OPC_EmitNode1Chain: case OPC_EmitNode2Chain: case OPC_MorphNodeTo: + case OPC_MorphNodeToByHwMode: case OPC_MorphNodeTo0: case OPC_MorphNodeTo1: case OPC_MorphNodeTo2: @@ -4183,11 +4254,20 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, else NumVTs = MatcherTable[MatcherIndex++]; SmallVector VTs; - for (unsigned i = 0; i != NumVTs; ++i) { - MVT::SimpleValueType VT = getSimpleVT(MatcherTable, MatcherIndex); - if (VT == MVT::iPTR) - VT = TLI->getPointerTy(CurDAG->getDataLayout()).SimpleTy; - VTs.push_back(VT); + if (Opcode == OPC_EmitNodeByHwMode || Opcode == OPC_MorphNodeToByHwMode) { + for (unsigned i = 0; i != NumVTs; ++i) { + MVT VT = getHwModeVT(MatcherTable, MatcherIndex, *this); + if (VT == MVT::iPTR) + VT = TLI->getPointerTy(CurDAG->getDataLayout()); + VTs.push_back(VT); + } + } else { + for (unsigned i = 0; i != NumVTs; ++i) { + MVT::SimpleValueType VT = getSimpleVT(MatcherTable, MatcherIndex); + if (VT == MVT::iPTR) + VT = TLI->getPointerTy(CurDAG->getDataLayout()).SimpleTy; + VTs.push_back(VT); + } } if (EmitNodeInfo & OPFL_Chain) @@ -4254,7 +4334,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, // Create the node. MachineSDNode *Res = nullptr; bool IsMorphNodeTo = - Opcode == OPC_MorphNodeTo || + Opcode == OPC_MorphNodeTo || Opcode == OPC_MorphNodeToByHwMode || (Opcode >= OPC_MorphNodeTo0 && Opcode <= OPC_MorphNodeTo2GlueOutput); if (!IsMorphNodeTo) { // If this is a normal EmitNode command, just create the new node and diff --git a/llvm/test/TableGen/RegClassByHwMode.td b/llvm/test/TableGen/RegClassByHwMode.td index 2e5e6d08ddae6..508b098c9f44e 100644 --- a/llvm/test/TableGen/RegClassByHwMode.td +++ b/llvm/test/TableGen/RegClassByHwMode.td @@ -205,32 +205,11 @@ include "Common/RegClassByHwModeCommon.td" // ISEL-SDAG-NEXT: OPC_RecordMemRef, // ISEL-SDAG-NEXT: OPC_RecordNode, // #0 = 'st' chained node // ISEL-SDAG-NEXT: OPC_RecordChild1, // #1 = $val -// ISEL-SDAG-NEXT: OPC_CheckChild1TypeI64, +// ISEL-SDAG-NEXT: OPC_CheckChild1TypeByHwMode, /*{(*:i64),(m1:i64),(m2:i64)}*/0, // ISEL-SDAG-NEXT: OPC_RecordChild2, // #2 = $src -// ISEL-SDAG-NEXT: OPC_Scope /*2 children */, {{[0-9]+}}, // ->{{[0-9]+}} -// ISEL-SDAG-NEXT: OPC_CheckChild2TypeI32, +// ISEL-SDAG-NEXT: OPC_CheckChild2TypeByHwMode, /*{(*:i32),(m3:i64)}*/1, // ISEL-SDAG-NEXT: OPC_CheckPredicate0, // Predicate_unindexedstore // ISEL-SDAG-NEXT: OPC_CheckPredicate1, // Predicate_store -// ISEL-SDAG-NEXT: OPC_Scope /*3 children */, {{[0-9]+}}, // ->{{[0-9]+}} -// ISEL-SDAG-NEXT: OPC_CheckPatternPredicate0, // (Subtarget->hasAlignedRegisters()) -// ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, -// ISEL-SDAG-NEXT: OPC_MorphNodeTo0, TARGET_VAL(MyTarget::MY_STORE), 0|OPFL_Chain|OPFL_MemRefs, - -// ISEL-SDAG: /*Scope*/ -// ISEL-SDAG: OPC_CheckPatternPredicate1, // (Subtarget->hasUnalignedRegisters()) -// ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, -// ISEL-SDAG-NEXT: OPC_MorphNodeTo0, TARGET_VAL(MyTarget::MY_STORE), 0|OPFL_Chain|OPFL_MemRefs, - -// ISEL-SDAG: /*Scope*/ -// ISEL-SDAG: OPC_CheckPatternPredicate2, // !((Subtarget->hasAlignedRegisters())) && !((Subtarget->hasUnalignedRegisters())) && !((Subtarget->isPtr64())) -// ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, -// ISEL-SDAG-NEXT: OPC_MorphNodeTo0, TARGET_VAL(MyTarget::MY_STORE), 0|OPFL_Chain|OPFL_MemRefs, - -// ISEL-SDAG: /*Scope*/ -// ISEL-SDAG-NEXT: OPC_CheckChild2TypeI64, -// ISEL-SDAG-NEXT: OPC_CheckPredicate0, // Predicate_unindexedstore -// ISEL-SDAG-NEXT: OPC_CheckPredicate1, // Predicate_store -// ISEL-SDAG-NEXT: OPC_CheckPatternPredicate3, // (Subtarget->isPtr64()) // ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, // ISEL-SDAG-NEXT: OPC_MorphNodeTo0, TARGET_VAL(MyTarget::MY_STORE), 0|OPFL_Chain|OPFL_MemRefs, @@ -238,33 +217,12 @@ include "Common/RegClassByHwModeCommon.td" // ISEL-SDAG-NEXT: OPC_RecordMemRef, // ISEL-SDAG-NEXT: OPC_RecordNode, // #0 = 'ld' chained node // ISEL-SDAG-NEXT: OPC_RecordChild1, // #1 = $src -// ISEL-SDAG-NEXT: OPC_CheckTypeI64, -// ISEL-SDAG-NEXT: OPC_Scope /*2 children */, {{[0-9]+}}, // ->{{[0-9]+}} -// ISEL-SDAG-NEXT: OPC_CheckChild1TypeI32, -// ISEL-SDAG-NEXT: OPC_CheckPredicate2, // Predicate_unindexedload -// ISEL-SDAG-NEXT: OPC_CheckPredicate3, // Predicate_load -// ISEL-SDAG-NEXT: OPC_Scope /*3 children */, {{[0-9]+}}, // ->{{[0-9]+}} -// ISEL-SDAG-NEXT: OPC_CheckPatternPredicate0, // (Subtarget->hasAlignedRegisters()) -// ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, -// ISEL-SDAG-NEXT: OPC_MorphNodeTo1, TARGET_VAL(MyTarget::MY_LOAD), 0|OPFL_Chain|OPFL_MemRefs, - -// ISEL-SDAG: /*Scope*/ -// ISEL-SDAG: OPC_CheckPatternPredicate1, // (Subtarget->hasUnalignedRegisters()) -// ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, -// ISEL-SDAG-NEXT: OPC_MorphNodeTo1, TARGET_VAL(MyTarget::MY_LOAD), 0|OPFL_Chain|OPFL_MemRefs, - -// ISEL-SDAG: /*Scope*/ -// ISEL-SDAG: OPC_CheckPatternPredicate2, // !((Subtarget->hasAlignedRegisters())) && !((Subtarget->hasUnalignedRegisters())) && !((Subtarget->isPtr64())) -// ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, -// ISEL-SDAG-NEXT: OPC_MorphNodeTo1, TARGET_VAL(MyTarget::MY_LOAD), 0|OPFL_Chain|OPFL_MemRefs, - -// ISEL-SDAG: /*Scope*/ -// ISEL-SDAG-NEXT: OPC_CheckChild1TypeI64, +// ISEL-SDAG-NEXT: OPC_CheckChild1TypeByHwMode, /*{(*:i32),(m3:i64)}*/1, // ISEL-SDAG-NEXT: OPC_CheckPredicate2, // Predicate_unindexedload // ISEL-SDAG-NEXT: OPC_CheckPredicate3, // Predicate_load -// ISEL-SDAG-NEXT: OPC_CheckPatternPredicate3, // (Subtarget->isPtr64()) +// ISEL-SDAG-NEXT: OPC_CheckTypeI64, // ISEL-SDAG-NEXT: OPC_EmitMergeInputChains1_0, -// ISEL-SDAG-NEXT: OPC_MorphNodeTo1, TARGET_VAL(MyTarget::MY_LOAD), 0|OPFL_Chain|OPFL_MemRefs, +// ISEL-SDAG-NEXT: OPC_MorphNodeToByHwMode, TARGET_VAL(MyTarget::MY_LOAD), 0|OPFL_Chain|OPFL_MemRefs, diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp index c3e07979467de..3ed5fe5d806ae 100644 --- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp +++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp @@ -94,13 +94,15 @@ bool TypeSetByHwMode::isValueTypeByHwMode(bool AllowEmpty) const { return true; } -ValueTypeByHwMode TypeSetByHwMode::getValueTypeByHwMode() const { +ValueTypeByHwMode TypeSetByHwMode::getValueTypeByHwMode(bool SkipEmpty) const { assert(isValueTypeByHwMode(true) && "The type set has multiple types for at least one HW mode"); ValueTypeByHwMode VVT; VVT.PtrAddrSpace = AddrSpace; for (const auto &I : *this) { + if (SkipEmpty && I.second.empty()) + continue; MVT T = I.second.empty() ? MVT::Other : *I.second.begin(); VVT.insertTypeForMode(I.first, T); } @@ -1480,10 +1482,9 @@ static unsigned getPatternSize(const TreePatternNode &P, // Count children in the count if they are also nodes. for (const TreePatternNode &Child : P.children()) { if (!Child.isLeaf() && Child.getNumTypes()) { - const TypeSetByHwMode &T0 = Child.getExtType(0); - // At this point, all variable type sets should be simple, i.e. only - // have a default mode. - if (T0.getMachineValueType() != MVT::Other) { + // FIXME: Can we assume non-simple VTs should be counted? + auto VVT = Child.getType(0); + if (llvm::any_of(VVT, [](auto &P) { return P.second != MVT::Other; })) { Size += getPatternSize(Child, CGP); continue; } @@ -3321,7 +3322,7 @@ void TreePattern::dump() const { print(errs()); } // CodeGenDAGPatterns implementation // -CodeGenDAGPatterns::CodeGenDAGPatterns(const RecordKeeper &R) +CodeGenDAGPatterns::CodeGenDAGPatterns(const RecordKeeper &R, bool ExpandHwMode) : Records(R), Target(R), Intrinsics(R), LegalVTS(Target.getLegalValueTypes()), LegalPtrVTS(ComputeLegalPtrTypes()) { @@ -3341,7 +3342,8 @@ CodeGenDAGPatterns::CodeGenDAGPatterns(const RecordKeeper &R) // Break patterns with parameterized types into a series of patterns, // where each one has a fixed type and is predicated on the conditions // of the associated HW mode. - ExpandHwModeBasedTypes(); + if (ExpandHwMode) + ExpandHwModeBasedTypes(); // Infer instruction flags. For example, we can detect loads, // stores, and side effects in many cases by examining an diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h index 220fa43bf5037..7d93e9ce126d5 100644 --- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h +++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.h @@ -190,7 +190,7 @@ struct TypeSetByHwMode : public InfoByHwMode { SetType &getOrCreate(unsigned Mode) { return Map[Mode]; } bool isValueTypeByHwMode(bool AllowEmpty) const; - ValueTypeByHwMode getValueTypeByHwMode() const; + ValueTypeByHwMode getValueTypeByHwMode(bool SkipEmpty = false) const; LLVM_ATTRIBUTE_ALWAYS_INLINE bool isMachineValueType() const { @@ -672,6 +672,9 @@ class TreePatternNode : public RefCountedBase { // Type accessors. unsigned getNumTypes() const { return Types.size(); } + ValueTypeByHwMode getType(unsigned ResNo) const { + return Types[ResNo].getValueTypeByHwMode(/*SkipEmpty=*/true); + } const std::vector &getExtTypes() const { return Types; } const TypeSetByHwMode &getExtType(unsigned ResNo) const { return Types[ResNo]; @@ -1123,7 +1126,7 @@ class CodeGenDAGPatterns { unsigned NumScopes = 0; public: - CodeGenDAGPatterns(const RecordKeeper &R); + CodeGenDAGPatterns(const RecordKeeper &R, bool ExpandHwMode = true); CodeGenTarget &getTargetInfo() { return Target; } const CodeGenTarget &getTargetInfo() const { return Target; } diff --git a/llvm/utils/TableGen/DAGISelEmitter.cpp b/llvm/utils/TableGen/DAGISelEmitter.cpp index 65d2835018f1e..dc66967584f71 100644 --- a/llvm/utils/TableGen/DAGISelEmitter.cpp +++ b/llvm/utils/TableGen/DAGISelEmitter.cpp @@ -30,7 +30,7 @@ class DAGISelEmitter { const CodeGenDAGPatterns CGP; public: - explicit DAGISelEmitter(const RecordKeeper &R) : Records(R), CGP(R) {} + explicit DAGISelEmitter(const RecordKeeper &R) : Records(R), CGP(R, false) {} void run(raw_ostream &OS); }; } // End anonymous namespace @@ -89,13 +89,30 @@ struct PatternSortingPredicate { const TreePatternNode < = LHS->getSrcPattern(); const TreePatternNode &RT = RHS->getSrcPattern(); - MVT LHSVT = LT.getNumTypes() != 0 ? LT.getSimpleType(0) : MVT::Other; - MVT RHSVT = RT.getNumTypes() != 0 ? RT.getSimpleType(0) : MVT::Other; - if (LHSVT.isVector() != RHSVT.isVector()) - return RHSVT.isVector(); + bool LHSIsVector = false; + bool RHSIsVector = false; + bool LHSIsFP = false; + bool RHSIsFP = false; - if (LHSVT.isFloatingPoint() != RHSVT.isFloatingPoint()) - return RHSVT.isFloatingPoint(); + if (LT.getNumTypes() != 0) { + for (auto [_, VT] : LT.getType(0)) { + LHSIsVector |= VT.isVector(); + LHSIsFP |= VT.isFloatingPoint(); + } + } + + if (RT.getNumTypes() != 0) { + for (auto [_, VT] : RT.getType(0)) { + RHSIsVector |= VT.isVector(); + RHSIsFP |= VT.isFloatingPoint(); + } + } + + if (LHSIsVector != RHSIsVector) + return RHSIsVector; + + if (LHSIsFP != RHSIsFP) + return RHSIsFP; // Otherwise, if the patterns might both match, sort based on complexity, // which means that we prefer to match patterns that cover more nodes in the diff --git a/llvm/utils/TableGen/DAGISelMatcher.cpp b/llvm/utils/TableGen/DAGISelMatcher.cpp index 3ec20e318f680..a68ebf3551cf3 100644 --- a/llvm/utils/TableGen/DAGISelMatcher.cpp +++ b/llvm/utils/TableGen/DAGISelMatcher.cpp @@ -180,8 +180,7 @@ void SwitchOpcodeMatcher::printImpl(raw_ostream &OS, indent Indent) const { } void CheckTypeMatcher::printImpl(raw_ostream &OS, indent Indent) const { - OS << Indent << "CheckType " << getEnumName(Type) << ", ResNo=" << ResNo - << '\n'; + OS << Indent << "CheckType " << Type << ", ResNo=" << ResNo << '\n'; } void SwitchTypeMatcher::printImpl(raw_ostream &OS, indent Indent) const { @@ -194,8 +193,7 @@ void SwitchTypeMatcher::printImpl(raw_ostream &OS, indent Indent) const { } void CheckChildTypeMatcher::printImpl(raw_ostream &OS, indent Indent) const { - OS << Indent << "CheckChildType " << ChildNo << " " << getEnumName(Type) - << '\n'; + OS << Indent << "CheckChildType " << ChildNo << " " << Type << '\n'; } void CheckIntegerMatcher::printImpl(raw_ostream &OS, indent Indent) const { @@ -245,7 +243,7 @@ void CheckImmAllZerosVMatcher::printImpl(raw_ostream &OS, indent Indent) const { } void EmitIntegerMatcher::printImpl(raw_ostream &OS, indent Indent) const { - OS << Indent << "EmitInteger " << Val << " VT=" << getEnumName(VT) << '\n'; + OS << Indent << "EmitInteger " << Val << " VT=" << VT << '\n'; } void EmitRegisterMatcher::printImpl(raw_ostream &OS, indent Indent) const { @@ -254,7 +252,7 @@ void EmitRegisterMatcher::printImpl(raw_ostream &OS, indent Indent) const { OS << Reg->getName(); else OS << "zero_reg"; - OS << " VT=" << getEnumName(VT) << '\n'; + OS << " VT=" << VT << '\n'; } void EmitConvertToTargetMatcher::printImpl(raw_ostream &OS, @@ -281,8 +279,8 @@ void EmitNodeMatcherCommon::printImpl(raw_ostream &OS, indent Indent) const { OS << (isa(this) ? "MorphNodeTo: " : "EmitNode: ") << CGI.Namespace << "::" << CGI.getName() << ": "; - for (MVT VT : VTs) - OS << ' ' << getEnumName(VT); + for (const ValueTypeByHwMode &VT : VTs) + OS << ' ' << VT; OS << '('; for (unsigned Operand : Operands) OS << Operand << ' '; @@ -316,19 +314,26 @@ void MorphNodeToMatcher::anchor() {} // isContradictoryImpl Implementations. -static bool TypesAreContradictory(MVT T1, MVT T2) { +static bool TypesAreContradictory(const ValueTypeByHwMode &VT1, + const ValueTypeByHwMode &VT2) { // If the two types are the same, then they are the same, so they don't // contradict. - if (T1 == T2) + if (VT1 == VT2) return false; + if (!VT1.isSimple() || !VT2.isSimple()) + return false; + + MVT T1 = VT1.getSimple(); + MVT T2 = VT2.getSimple(); + if (T1 == MVT::pAny) - return TypesAreContradictory(MVT::iPTR, T2) && - TypesAreContradictory(MVT::cPTR, T2); + return TypesAreContradictory(MVT(MVT::iPTR), T2) && + TypesAreContradictory(MVT(MVT::cPTR), T2); if (T2 == MVT::pAny) - return TypesAreContradictory(T1, MVT::iPTR) && - TypesAreContradictory(T1, MVT::cPTR); + return TypesAreContradictory(T1, MVT(MVT::iPTR)) && + TypesAreContradictory(T1, MVT(MVT::cPTR)); // If either type is about iPtr, then they don't conflict unless the other // one is not a scalar integer type. diff --git a/llvm/utils/TableGen/DAGISelMatcher.h b/llvm/utils/TableGen/DAGISelMatcher.h index 192d5c47d3489..f2a75147c2aca 100644 --- a/llvm/utils/TableGen/DAGISelMatcher.h +++ b/llvm/utils/TableGen/DAGISelMatcher.h @@ -9,6 +9,7 @@ #ifndef LLVM_UTILS_TABLEGEN_COMMON_DAGISELMATCHER_H #define LLVM_UTILS_TABLEGEN_COMMON_DAGISELMATCHER_H +#include "Common/InfoByHwMode.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -516,14 +517,14 @@ class SwitchOpcodeMatcher : public Matcher { /// CheckTypeMatcher - This checks to see if the current node has the /// specified type at the specified result, if not it fails to match. class CheckTypeMatcher : public Matcher { - MVT Type; + ValueTypeByHwMode Type; unsigned ResNo; public: - CheckTypeMatcher(MVT type, unsigned resno) - : Matcher(CheckType), Type(type), ResNo(resno) {} + CheckTypeMatcher(ValueTypeByHwMode type, unsigned resno) + : Matcher(CheckType), Type(std::move(type)), ResNo(resno) {} - MVT getType() const { return Type; } + const ValueTypeByHwMode &getType() const { return Type; } unsigned getResNo() const { return ResNo; } static bool classof(const Matcher *N) { return N->getKind() == CheckType; } @@ -565,14 +566,14 @@ class SwitchTypeMatcher : public Matcher { /// specified type, if not it fails to match. class CheckChildTypeMatcher : public Matcher { unsigned ChildNo; - MVT Type; + ValueTypeByHwMode Type; public: - CheckChildTypeMatcher(unsigned childno, MVT type) - : Matcher(CheckChildType), ChildNo(childno), Type(type) {} + CheckChildTypeMatcher(unsigned childno, ValueTypeByHwMode type) + : Matcher(CheckChildType), ChildNo(childno), Type(std::move(type)) {} unsigned getChildNo() const { return ChildNo; } - MVT getType() const { return Type; } + const ValueTypeByHwMode &getType() const { return Type; } static bool classof(const Matcher *N) { return N->getKind() == CheckChildType; @@ -831,22 +832,20 @@ class EmitIntegerMatcher : public Matcher { // Optional string to give the value a symbolic name for readability. std::string Str; int64_t Val; - MVT VT; + ValueTypeByHwMode VT; unsigned ResultNo; public: - EmitIntegerMatcher(int64_t val, MVT vt, unsigned resultNo) - : Matcher(EmitInteger), - Val(SignExtend64(val, MVT(vt).getFixedSizeInBits())), VT(vt), - ResultNo(resultNo) {} + EmitIntegerMatcher(int64_t val, ValueTypeByHwMode vt, unsigned resultNo) + : Matcher(EmitInteger), Val(val), VT(std::move(vt)), ResultNo(resultNo) {} EmitIntegerMatcher(const std::string &str, int64_t val, MVT vt, unsigned resultNo) : Matcher(EmitInteger), Str(str), Val(val), VT(vt), ResultNo(resultNo) {} const std::string &getString() const { return Str; } int64_t getValue() const { return Val; } - MVT getVT() const { return VT; } + const ValueTypeByHwMode &getVT() const { return VT; } unsigned getResultNo() const { return ResultNo; } static bool classof(const Matcher *N) { return N->getKind() == EmitInteger; } @@ -865,16 +864,18 @@ class EmitRegisterMatcher : public Matcher { /// Reg - The def for the register that we're emitting. If this is null, then /// this is a reference to zero_reg. const CodeGenRegister *Reg; - MVT VT; + ValueTypeByHwMode VT; unsigned ResultNo; public: - EmitRegisterMatcher(const CodeGenRegister *reg, MVT vt, unsigned resultNo) - : Matcher(EmitRegister), Reg(reg), VT(vt), ResultNo(resultNo) {} + EmitRegisterMatcher(const CodeGenRegister *reg, ValueTypeByHwMode vt, + unsigned resultNo) + : Matcher(EmitRegister), Reg(reg), VT(std::move(vt)), ResultNo(resultNo) { + } const CodeGenRegister *getReg() const { return Reg; } - MVT getVT() const { return VT; } + const ValueTypeByHwMode &getVT() const { return VT; } unsigned getResultNo() const { return ResultNo; } static bool classof(const Matcher *N) { return N->getKind() == EmitRegister; } @@ -1002,7 +1003,7 @@ class EmitNodeXFormMatcher : public Matcher { /// MorphNodeTo. class EmitNodeMatcherCommon : public Matcher { const CodeGenInstruction &CGI; - const SmallVector VTs; + const SmallVector VTs; const SmallVector Operands; bool HasChain, HasInGlue, HasOutGlue, HasMemRefs; @@ -1012,7 +1013,8 @@ class EmitNodeMatcherCommon : public Matcher { int NumFixedArityOperands; public: - EmitNodeMatcherCommon(const CodeGenInstruction &cgi, ArrayRef vts, + EmitNodeMatcherCommon(const CodeGenInstruction &cgi, + ArrayRef vts, ArrayRef operands, bool hasChain, bool hasInGlue, bool hasOutGlue, bool hasmemrefs, int numfixedarityoperands, bool isMorphNodeTo) @@ -1024,7 +1026,7 @@ class EmitNodeMatcherCommon : public Matcher { const CodeGenInstruction &getInstruction() const { return CGI; } unsigned getNumVTs() const { return VTs.size(); } - MVT getVT(unsigned i) const { + const ValueTypeByHwMode &getVT(unsigned i) const { assert(i < VTs.size()); return VTs[i]; } @@ -1035,8 +1037,8 @@ class EmitNodeMatcherCommon : public Matcher { return Operands[i]; } - const SmallVectorImpl &getVTList() const { return VTs; } - const SmallVectorImpl &getOperandList() const { return Operands; } + ArrayRef getVTList() const { return VTs; } + ArrayRef getOperandList() const { return Operands; } bool hasChain() const { return HasChain; } bool hasInGlue() const { return HasInGlue; } @@ -1059,9 +1061,10 @@ class EmitNodeMatcher : public EmitNodeMatcherCommon { unsigned FirstResultSlot; public: - EmitNodeMatcher(const CodeGenInstruction &cgi, ArrayRef vts, - ArrayRef operands, bool hasChain, bool hasInGlue, - bool hasOutGlue, bool hasmemrefs, int numfixedarityoperands, + EmitNodeMatcher(const CodeGenInstruction &cgi, + ArrayRef vts, ArrayRef operands, + bool hasChain, bool hasInGlue, bool hasOutGlue, + bool hasmemrefs, int numfixedarityoperands, unsigned firstresultslot) : EmitNodeMatcherCommon(cgi, vts, operands, hasChain, hasInGlue, hasOutGlue, hasmemrefs, numfixedarityoperands, @@ -1078,7 +1081,8 @@ class MorphNodeToMatcher : public EmitNodeMatcherCommon { const PatternToMatch &Pattern; public: - MorphNodeToMatcher(const CodeGenInstruction &cgi, ArrayRef vts, + MorphNodeToMatcher(const CodeGenInstruction &cgi, + ArrayRef vts, ArrayRef operands, bool hasChain, bool hasInGlue, bool hasOutGlue, bool hasmemrefs, int numfixedarityoperands, const PatternToMatch &pattern) diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp index 3d573483afebf..95338b906568a 100644 --- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp @@ -71,6 +71,8 @@ class MatcherTableEmitter { std::vector VecIncludeStrings; MapVector> VecPatterns; + std::map ValueTypeMap; + unsigned getPatternIdxFromTable(std::string &&P, std::string &&include_loc) { const auto [It, Inserted] = VecPatterns.try_emplace(std::move(P), VecPatterns.size()); @@ -174,6 +176,8 @@ class MatcherTableEmitter { void EmitPredicateFunctions(raw_ostream &OS); + void EmitValueTypeFunction(raw_ostream &OS); + void EmitHistogram(const Matcher *N, raw_ostream &OS); void EmitPatternMatchTable(raw_ostream &OS); @@ -212,6 +216,21 @@ class MatcherTableEmitter { } return Entry - 1; } + + unsigned getValueTypeID(const ValueTypeByHwMode &VT) { + unsigned &Entry = ValueTypeMap[VT]; + if (Entry == 0) { + Entry = ValueTypeMap.size(); + if (Entry > 256) + report_fatal_error( + "More ValueType by HwMode than fit in a 8-bit index"); + } + + return Entry - 1; + } + + unsigned emitValueTypeByHwMode(const ValueTypeByHwMode &VTBH, + raw_ostream &OS); }; } // end anonymous namespace. @@ -436,6 +455,14 @@ static unsigned emitMVT(MVT VT, raw_ostream &OS) { return EmitVBRValue(VT.SimpleTy, OS); } +unsigned +MatcherTableEmitter::emitValueTypeByHwMode(const ValueTypeByHwMode &VTBH, + raw_ostream &OS) { + if (!OmitComments) + OS << "/*" << VTBH << "*/"; + OS << getValueTypeID(VTBH) << ','; + return 1; +} /// EmitMatcher - Emit bytes for the specified matcher and return /// the number of bytes emitted. unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, @@ -668,39 +695,66 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, } case Matcher::CheckType: { + const ValueTypeByHwMode &VTBH = cast(N)->getType(); + if (VTBH.isSimple()) { + MVT VT = VTBH.getSimple(); + if (cast(N)->getResNo() == 0) { + switch (VT.SimpleTy) { + case MVT::i32: + case MVT::i64: + OS << "OPC_CheckTypeI" << MVT(VT).getSizeInBits() << ",\n"; + return 1; + default: + OS << "OPC_CheckType, "; + unsigned NumBytes = emitMVT(VT, OS); + OS << '\n'; + return NumBytes + 1; + } + } + + OS << "OPC_CheckTypeRes, " << cast(N)->getResNo() + << ", "; + unsigned NumBytes = + emitMVT(cast(N)->getType().getSimple(), OS); + OS << '\n'; + return NumBytes + 2; + } + + unsigned OpSize = 1; if (cast(N)->getResNo() == 0) { - MVT VT = cast(N)->getType(); + OS << "OPC_CheckTypeByHwMode, "; + } else { + OS << "OPC_CheckTypeResByHwMode, " + << cast(N)->getResNo() << ", "; + OpSize += 1; + } + OpSize += emitValueTypeByHwMode(VTBH, OS); + OS << '\n'; + return OpSize; + } + + case Matcher::CheckChildType: { + const ValueTypeByHwMode &VTBH = cast(N)->getType(); + if (VTBH.isSimple()) { + MVT VT = VTBH.getSimple(); switch (VT.SimpleTy) { case MVT::i32: case MVT::i64: - OS << "OPC_CheckTypeI" << MVT(VT).getSizeInBits() << ",\n"; + OS << "OPC_CheckChild" << cast(N)->getChildNo() + << "TypeI" << MVT(VT).getSizeInBits() << ",\n"; return 1; default: - OS << "OPC_CheckType, "; + OS << "OPC_CheckChild" << cast(N)->getChildNo() + << "Type, "; unsigned NumBytes = emitMVT(VT, OS); - OS << "\n"; + OS << '\n'; return NumBytes + 1; } - } - OS << "OPC_CheckTypeRes, " << cast(N)->getResNo() << ", "; - unsigned NumBytes = emitMVT(cast(N)->getType(), OS); - OS << "\n"; - return NumBytes + 2; - } - - case Matcher::CheckChildType: { - MVT VT = cast(N)->getType(); - switch (VT.SimpleTy) { - case MVT::i32: - case MVT::i64: - OS << "OPC_CheckChild" << cast(N)->getChildNo() - << "TypeI" << MVT(VT).getSizeInBits() << ",\n"; - return 1; - default: + } else { OS << "OPC_CheckChild" << cast(N)->getChildNo() - << "Type, "; - unsigned NumBytes = emitMVT(VT, OS); - OS << "\n"; + << "TypeByHwMode, "; + unsigned NumBytes = emitValueTypeByHwMode(VTBH, OS); + OS << '\n'; return NumBytes + 1; } } @@ -793,20 +847,27 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, const auto *IM = cast(N); int64_t Val = IM->getValue(); const std::string &Str = IM->getString(); - MVT VT = IM->getVT(); + const ValueTypeByHwMode &VTBH = IM->getVT(); unsigned TypeBytes = 0; - switch (VT.SimpleTy) { - case MVT::i8: - case MVT::i16: - case MVT::i32: - case MVT::i64: - OS << "OPC_EmitIntegerI" << VT.getSizeInBits() << ", "; - break; - default: - OS << "OPC_EmitInteger, "; - TypeBytes = emitMVT(VT, OS); + if (VTBH.isSimple()) { + MVT VT = VTBH.getSimple(); + switch (VT.SimpleTy) { + case MVT::i8: + case MVT::i16: + case MVT::i32: + case MVT::i64: + OS << "OPC_EmitIntegerI" << VT.getSizeInBits() << ", "; + break; + default: + OS << "OPC_EmitInteger, "; + TypeBytes = emitMVT(VT, OS); + OS << ' '; + break; + } + } else { + OS << "OPC_EmitIntegerByHwMode, "; + TypeBytes = emitValueTypeByHwMode(VTBH, OS); OS << ' '; - break; } // If the value is 63 or smaller, use the string directly. Otherwise, use // a VBR. @@ -829,26 +890,41 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, case Matcher::EmitRegister: { const EmitRegisterMatcher *Matcher = cast(N); const CodeGenRegister *Reg = Matcher->getReg(); - MVT VT = Matcher->getVT(); + const ValueTypeByHwMode &VTBH = Matcher->getVT(); unsigned OpBytes; - // If the enum value of the register is larger than one byte can handle, - // use EmitRegister2. - if (Reg && Reg->EnumValue > 255) { - OS << "OPC_EmitRegister2, "; - OpBytes = emitMVT(VT, OS); - OS << "TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n"; - return OpBytes + 3; - } - switch (VT.SimpleTy) { - case MVT::i32: - case MVT::i64: - OpBytes = 1; - OS << "OPC_EmitRegisterI" << VT.getSizeInBits() << ", "; - break; - default: - OS << "OPC_EmitRegister, "; - OpBytes = emitMVT(VT, OS) + 1; - break; + if (VTBH.isSimple()) { + MVT VT = VTBH.getSimple(); + // If the enum value of the register is larger than one byte can handle, + // use EmitRegister2. + if (Reg && Reg->EnumValue > 255) { + OS << "OPC_EmitRegister2, "; + OpBytes = emitMVT(VT, OS); + OS << " TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n"; + return OpBytes + 3; + } + switch (VT.SimpleTy) { + case MVT::i32: + case MVT::i64: + OpBytes = 1; + OS << "OPC_EmitRegisterI" << VT.getSizeInBits() << ", "; + break; + default: + OS << "OPC_EmitRegister, "; + OpBytes = emitMVT(VT, OS) + 1; + OS << ' '; + break; + } + } else { + if (Reg && Reg->EnumValue > 255) { + OS << "OPC_EmitRegisterByHwMode2, "; + OpBytes = emitValueTypeByHwMode(VTBH, OS); + OS << " TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n"; + return OpBytes + 3; + } + + OS << "OPC_EmitRegisterByHwMode, "; + OpBytes = emitValueTypeByHwMode(VTBH, OS) + 1; + OS << ' '; } if (Reg) { OS << getQualifiedName(Reg->TheDef); @@ -955,10 +1031,16 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, OS << "OPC_CaptureDeactivationSymbol,\n"; OS.indent(FullIndexWidth + Indent); } + + bool ByHwMode = + llvm::any_of(EN->getVTList(), [](const ValueTypeByHwMode &VT) { + return !VT.isSimple(); + }); + bool IsEmitNode = isa(EN); OS << (IsEmitNode ? "OPC_EmitNode" : "OPC_MorphNodeTo"); unsigned NumVTs = EN->getNumVTs(); - bool CompressVTs = NumVTs < 3; + bool CompressVTs = !ByHwMode && EN->getNumVTs() < 3; bool CompressNodeInfo = false; if (CompressVTs) { OS << NumVTs; @@ -987,6 +1069,9 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, } } + if (ByHwMode) + OS << "ByHwMode"; + const CodeGenInstruction &CGI = EN->getInstruction(); OS << ", TARGET_VAL(" << CGI.Namespace << "::" << CGI.TheDef->getName() << ")"; @@ -1014,9 +1099,17 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, OS << ","; } unsigned NumTypeBytes = 0; - for (unsigned i = 0, e = EN->getNumVTs(); i != e; ++i) { - OS << ' '; - NumTypeBytes += emitMVT(EN->getVT(i), OS); + if (ByHwMode) { + for (unsigned i = 0, e = EN->getNumVTs(); i != e; ++i) { + OS << ' '; + const ValueTypeByHwMode &VTBH = EN->getVT(i); + NumTypeBytes += emitValueTypeByHwMode(VTBH, OS); + } + } else { + for (unsigned i = 0, e = EN->getNumVTs(); i != e; ++i) { + OS << ' '; + NumTypeBytes += emitMVT(EN->getVT(i).getSimple(), OS); + } } OS << ' ' << EN->getNumOperands(); @@ -1261,6 +1354,40 @@ void MatcherTableEmitter::EmitPredicateFunctions(raw_ostream &OS) { } } +void MatcherTableEmitter::EmitValueTypeFunction(raw_ostream &OS) { + if (ValueTypeMap.empty()) + return; + + BeginEmitFunction(OS, "MVT", "getValueTypeForHwMode(unsigned Index) const", + /*AddOverride=*/true); + OS << "{\n"; + + OS << " switch (Index) {\n"; + OS << " default: llvm_unreachable(\"Unexpected index\");\n"; + + for (const auto &[VTs, Idx] : ValueTypeMap) { + OS << " case " << (Idx - 1) << ":\n"; + OS << " switch (HwMode) {\n"; + if (!VTs.hasDefault()) + OS << " default:\n return MVT();\n"; + for (const auto [Mode, VT] : VTs) { + if (Mode == DefaultMode) + OS << " default:\n"; + else + OS << " case " << Mode << ":\n"; + OS << " return " << getEnumName(VT) << ";\n"; + } + + OS << " }\n"; + OS << " break;\n"; + } + + OS << " }\n"; + + OS << "}\n"; + EndEmitFunction(OS); +} + static StringRef getOpcodeString(Matcher::KindTy Kind) { switch (Kind) { case Matcher::Scope: @@ -1423,6 +1550,8 @@ void llvm::EmitMatcherTable(Matcher *TheMatcher, const CodeGenDAGPatterns &CGP, // Next up, emit the function for node and pattern predicates: MatcherEmitter.EmitPredicateFunctions(OS); + MatcherEmitter.EmitValueTypeFunction(OS); + if (InstrumentCoverage) MatcherEmitter.EmitPatternMatchTable(OS); diff --git a/llvm/utils/TableGen/DAGISelMatcherGen.cpp b/llvm/utils/TableGen/DAGISelMatcherGen.cpp index fc0564b18a356..e8f146264b1ec 100644 --- a/llvm/utils/TableGen/DAGISelMatcherGen.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherGen.cpp @@ -536,7 +536,7 @@ void MatcherGen::EmitMatchCode(const TreePatternNode &N, } for (unsigned I : ResultsToTypeCheck) - AddMatcher(new CheckTypeMatcher(N.getSimpleType(I), I)); + AddMatcher(new CheckTypeMatcher(N.getType(I), I)); } /// EmitMatcherCode - Generate the code that matches the predicate of this @@ -660,7 +660,7 @@ void MatcherGen::EmitResultLeafAsOperand(const TreePatternNode &N, assert(N.isLeaf() && "Must be a leaf"); if (const IntInit *II = dyn_cast(N.getLeafValue())) { - AddMatcher(new EmitIntegerMatcher(II->getValue(), N.getSimpleType(0), + AddMatcher(new EmitIntegerMatcher(II->getValue(), N.getType(0), NextRecordedOperandNo)); ResultOps.push_back(NextRecordedOperandNo++); return; @@ -671,21 +671,21 @@ void MatcherGen::EmitResultLeafAsOperand(const TreePatternNode &N, const Record *Def = DI->getDef(); if (Def->isSubClassOf("Register")) { const CodeGenRegister *Reg = CGP.getTargetInfo().getRegBank().getReg(Def); - AddMatcher(new EmitRegisterMatcher(Reg, N.getSimpleType(0), - NextRecordedOperandNo)); + AddMatcher( + new EmitRegisterMatcher(Reg, N.getType(0), NextRecordedOperandNo)); ResultOps.push_back(NextRecordedOperandNo++); return; } if (Def->getName() == "zero_reg") { - AddMatcher(new EmitRegisterMatcher(nullptr, N.getSimpleType(0), + AddMatcher(new EmitRegisterMatcher(nullptr, N.getType(0), NextRecordedOperandNo)); ResultOps.push_back(NextRecordedOperandNo++); return; } if (Def->getName() == "undef_tied_input") { - MVT ResultVT = N.getSimpleType(0); + ValueTypeByHwMode ResultVT = N.getType(0); auto IDOperandNo = NextRecordedOperandNo++; const Record *ImpDef = Def->getRecords().getDef("IMPLICIT_DEF"); const CodeGenInstruction &II = CGP.getTargetInfo().getInstruction(ImpDef); @@ -879,9 +879,9 @@ void MatcherGen::EmitResultInstructionAsOperand( // Result order: node results, chain, glue // Determine the result types. - SmallVector ResultVTs; + SmallVector ResultVTs; for (unsigned i = 0, e = N.getNumTypes(); i != e; ++i) - ResultVTs.push_back(N.getSimpleType(i)); + ResultVTs.push_back(N.getType(i)); // If this is the root instruction of a pattern that has physical registers in // its result pattern, add output VTs for them. For example, X86 has: @@ -956,8 +956,9 @@ void MatcherGen::EmitResultInstructionAsOperand( NumFixedArityOperands, NextRecordedOperandNo)); // The non-chain and non-glue results of the newly emitted node get recorded. - for (MVT ResultVT : ResultVTs) { - if (ResultVT == MVT::Other || ResultVT == MVT::Glue) + for (const ValueTypeByHwMode &ResultVT : ResultVTs) { + if (ResultVT.isSimple() && (ResultVT.getSimple() == MVT::Other || + ResultVT.getSimple() == MVT::Glue)) break; OutputOps.push_back(NextRecordedOperandNo++); } diff --git a/llvm/utils/TableGen/DAGISelMatcherOpt.cpp b/llvm/utils/TableGen/DAGISelMatcherOpt.cpp index 222b73ddd8bef..9a40ecd30edff 100644 --- a/llvm/utils/TableGen/DAGISelMatcherOpt.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherOpt.cpp @@ -282,7 +282,7 @@ static void ContractNodes(std::unique_ptr &InputMatcherPtr, #endif if (ResultsMatch) { - ArrayRef VTs = EN->getVTList(); + ArrayRef VTs = EN->getVTList(); ArrayRef Operands = EN->getOperandList(); MatcherPtr->reset(new MorphNodeToMatcher( EN->getInstruction(), VTs, Operands, EN->hasChain(), @@ -518,10 +518,11 @@ static void FactorScope(std::unique_ptr &MatcherPtr) { if (AllTypeChecks) { CheckTypeMatcher *CTM = cast_or_null( FindNodeWithKind(Optn, Matcher::CheckType)); - if (!CTM || + if (!CTM || !CTM->getType().isSimple() || // iPTR/cPTR checks could alias any other case without us knowing, // don't bother with them. - CTM->getType() == MVT::iPTR || CTM->getType() == MVT::cPTR || + CTM->getType().getSimple() == MVT::iPTR || + CTM->getType().getSimple() == MVT::cPTR || // SwitchType only works for result #0. CTM->getResNo() != 0 || // If the CheckType isn't at the start of the list, see if we can move @@ -563,7 +564,7 @@ static void FactorScope(std::unique_ptr &MatcherPtr) { auto *CTM = cast(M); Matcher *MatcherWithoutCTM = Optn->unlinkNode(CTM); - MVT CTMTy = CTM->getType(); + MVT CTMTy = CTM->getType().getSimple(); delete CTM; unsigned &Entry = TypeEntry[CTMTy.SimpleTy];