-
Notifications
You must be signed in to change notification settings - Fork 29
[AIE2][AIE2P] Add combiners related to extract/insert/broadcast #723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,6 +66,95 @@ cl::opt<bool> MemsetOptimizations( | |
| "aie-optimize-memsets", cl::init(true), cl::Hidden, | ||
| cl::desc("Apply memset optimizations (peeling/align/etc.).")); | ||
|
|
||
| namespace { | ||
|
|
||
| bool isGenericExtractOpcode(unsigned Opc, const AIEBaseInstrInfo &TII) { | ||
| // Check if it's either SEXT or ZEXT extract | ||
| const unsigned ExtractSextOpc = TII.getGenericExtractVectorEltOpcode(true); | ||
| if (Opc == ExtractSextOpc) { | ||
| return true; | ||
| } | ||
| const unsigned ExtractZextOpc = TII.getGenericExtractVectorEltOpcode(false); | ||
| return Opc == ExtractZextOpc; | ||
| } | ||
|
|
||
| /// We conservatively implement only known cases. | ||
| bool mayMIShiftElements(const MachineInstr *MI) { | ||
| switch (MI->getOpcode()) { | ||
| case TargetOpcode::G_FMUL: | ||
| case TargetOpcode::G_FADD: | ||
| case TargetOpcode::G_FSUB: | ||
| return false; | ||
| case TargetOpcode::G_INTRINSIC: | ||
| case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: { | ||
| switch (cast<GIntrinsic>(MI)->getIntrinsicID()) { | ||
| case Intrinsic::aie2_v16accfloat_to_v16bf16: | ||
| case Intrinsic::aie2p_v16accfloat_to_v16bf16: | ||
| case Intrinsic::aie2p_v32accfloat_to_v32bf16: | ||
| case Intrinsic::aie2p_I512_I512_ACC1024_bf_mul_conf: | ||
| return false; | ||
| } | ||
| } | ||
| default: | ||
| return true; | ||
| } | ||
| } | ||
|
|
||
| /// Verify that all uses of a broadcast vector through a chain of operations | ||
| /// only extract from position 0. The chain may include G_CONCAT_VECTORS, | ||
| /// G_UNMERGE_VALUES, and vector operations. | ||
| /// \param Reg The register to verify uses for | ||
| /// \param MRI Machine register info | ||
| /// \param TII Target instruction info | ||
| /// \return true if all uses only extract position 0 | ||
| bool verifyBroadcastUsesOnlyExtractZero(Register Reg, MachineRegisterInfo &MRI, | ||
| const AIEBaseInstrInfo &TII) { | ||
| if (!MRI.hasOneNonDBGUser(Reg)) | ||
| return false; | ||
|
|
||
| MachineInstr *UserMI = &*MRI.use_nodbg_instructions(Reg).begin(); | ||
| unsigned Opcode = UserMI->getOpcode(); | ||
|
|
||
| // For concat, Reg should be the first src operand. | ||
| if (Opcode == TargetOpcode::G_CONCAT_VECTORS) { | ||
| if (UserMI->getOperand(1).getReg() != Reg) | ||
| return false; | ||
| return verifyBroadcastUsesOnlyExtractZero(UserMI->getOperand(0).getReg(), | ||
| MRI, TII); | ||
| // For unmerge, the useful operand should be the first one, | ||
| // the other ones, they should be dead. | ||
| } else if (Opcode == TargetOpcode::G_UNMERGE_VALUES) { | ||
| unsigned OpCount = 0; | ||
| for (auto &MO : UserMI->defs()) { | ||
| Register DefReg = MO.getReg(); | ||
| if (OpCount == 0 && !MRI.hasOneUse(DefReg)) | ||
| return false; | ||
| else if (OpCount && !MRI.use_empty(DefReg)) | ||
| return false; | ||
| OpCount++; | ||
| } | ||
| return verifyBroadcastUsesOnlyExtractZero(UserMI->getOperand(0).getReg(), | ||
| MRI, TII); | ||
| // If we extract from zero, we succeed, otherwise we fail. | ||
| } else if (isGenericExtractOpcode(Opcode, TII)) { | ||
| const Register UseIdxReg = UserMI->getOperand(2).getReg(); | ||
| auto UseIdx = getIConstantVRegValWithLookThrough(UseIdxReg, MRI); | ||
| return UseIdx && UseIdx->Value.getZExtValue() == 0; | ||
| // If we bitcast, we may need other lanes. | ||
| } else if (Opcode == TargetOpcode::G_BITCAST) { | ||
| return false; | ||
| } else { | ||
| if (mayMIShiftElements(UserMI)) | ||
| return false; | ||
| return verifyBroadcastUsesOnlyExtractZero(UserMI->getOperand(0).getReg(), | ||
| MRI, TII); | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| static unsigned getNumMaskUndefs(const ArrayRef<int> &Mask, | ||
| unsigned StartIndex) { | ||
| unsigned Count = 0; | ||
|
|
@@ -4226,8 +4315,7 @@ namespace { | |
| MachineInstr *getBcstFeedByAssertExtVecExtr(MachineInstr &MI, | ||
| MachineRegisterInfo &MRI, | ||
| const AIEBaseInstrInfo &TII) { | ||
| assert(MI.getOpcode() == TII.getGenericExtractVectorEltOpcode(false) || | ||
| MI.getOpcode() == TII.getGenericExtractVectorEltOpcode(true)); | ||
| assert(isGenericExtractOpcode(MI.getOpcode(), TII)); | ||
|
|
||
| /// Get single NonDebug User of \p MI with the opcode \p UseMIOpcode | ||
| auto GetSingleNonDbgUser = [&MRI](MachineInstr &MI, | ||
|
|
@@ -4266,8 +4354,7 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI, | |
| const AIEBaseInstrInfo &TII, | ||
| GISelChangeObserver &Observer, | ||
| BuildFnTy &MatchInfo) { | ||
| assert((MI.getOpcode() == TII.getGenericExtractVectorEltOpcode(false) || | ||
| MI.getOpcode() == TII.getGenericExtractVectorEltOpcode(true)) && | ||
| assert(isGenericExtractOpcode(MI.getOpcode(), TII) && | ||
| "Expected a extract_vector_elt"); | ||
| const MachineInstr *BcstMI = getBcstFeedByAssertExtVecExtr(MI, MRI, TII); | ||
| if (!BcstMI) | ||
|
|
@@ -4323,3 +4410,129 @@ bool llvm::matchMsbScalar(Register ScalarReg, Register BroadcastReg, | |
|
|
||
| return false; | ||
| } | ||
|
|
||
| /// Match a pattern where: | ||
| /// %18:_(<16 x s32>) = COPY $x0 | ||
| /// %10:_(<16 x s32>) = G_IMPLICIT_DEF | ||
| /// %9:_(s32) = G_CONSTANT i32 0 | ||
| /// %8:_(s32) = G_AIE_SEXT_EXTRACT_VECTOR_ELT %18(<16 x s32>), %9(s32) | ||
| /// %22:_(<16 x s32>) = G_AIE_INSERT_VECTOR_ELT %10, %8(s32), %9(s32) | ||
| /// | ||
| /// This can be simplified to: | ||
| /// %22:_(<16 x s32>) = COPY %18 | ||
| bool llvm::matchInsertExtractVectorEltToCopy(MachineInstr &MI, | ||
| MachineRegisterInfo &MRI, | ||
| const AIEBaseInstrInfo &TII, | ||
| BuildFnTy &MatchInfo) { | ||
| assert(MI.getOpcode() == TII.getGenericInsertVectorEltOpcode() && | ||
| "Expected G_AIE_INSERT_VECTOR_ELT"); | ||
|
|
||
| // Get the insert operands | ||
| const Register InsertDstReg = MI.getOperand(0).getReg(); | ||
| const Register InsertSrcVecReg = MI.getOperand(1).getReg(); | ||
| const Register InsertedEltReg = MI.getOperand(2).getReg(); | ||
| const Register InsertIdxReg = MI.getOperand(3).getReg(); | ||
|
|
||
| // Check that the insert source vector is G_IMPLICIT_DEF | ||
| const MachineInstr *InsertSrcMI = MRI.getVRegDef(InsertSrcVecReg); | ||
| if (!InsertSrcMI || InsertSrcMI->getOpcode() != TargetOpcode::G_IMPLICIT_DEF) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I have been using broadcast instead of insert into implicit def for the FMUL implementation. It will not work there.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Humm, I was working on the FADD mir. It looks we use different approaches in different legalizations. |
||
| return false; | ||
|
|
||
| // Get the definition of the inserted element | ||
| const MachineInstr *ExtractMI = MRI.getVRegDef(InsertedEltReg); | ||
| if (!ExtractMI) | ||
| return false; | ||
|
|
||
| // Check if it's either SEXT or ZEXT extract | ||
| if (!isGenericExtractOpcode(ExtractMI->getOpcode(), TII)) | ||
| return false; | ||
|
|
||
| // Get extract operands | ||
| const Register ExtractSrcVecReg = ExtractMI->getOperand(1).getReg(); | ||
| const Register ExtractIdxReg = ExtractMI->getOperand(2).getReg(); | ||
|
|
||
| // Verify that the insert destination vector type matches the extract source | ||
| // vector type | ||
| const LLT InsertDstTy = MRI.getType(InsertDstReg); | ||
| const LLT ExtractSrcTy = MRI.getType(ExtractSrcVecReg); | ||
|
|
||
| if (InsertDstTy != ExtractSrcTy) | ||
| return false; | ||
|
|
||
| // Check that insert and extract indices are the same | ||
| // They can be the same register, or both constants with the same value | ||
| if (InsertIdxReg != ExtractIdxReg) { | ||
| auto InsertIdxCst = getIConstantVRegValWithLookThrough(InsertIdxReg, MRI); | ||
| auto ExtractIdxCst = getIConstantVRegValWithLookThrough(ExtractIdxReg, MRI); | ||
| if (!InsertIdxCst || !ExtractIdxCst || | ||
| InsertIdxCst->Value != ExtractIdxCst->Value) | ||
| return false; | ||
| } | ||
|
|
||
| // Copy the extract source vector (the real vector) to the insert destination | ||
| MatchInfo = [=](MachineIRBuilder &B) { | ||
| B.buildCopy(InsertDstReg, ExtractSrcVecReg); | ||
| }; | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| /// Match a pattern where a broadcast is fed by an extract from position 0, | ||
| /// and all uses of the broadcast through a chain of operations only extract | ||
| /// from position 0. This allows us to replace the broadcast with a copy of | ||
| /// the original vector. | ||
| /// | ||
| /// Pattern: | ||
| /// %200:_(s32) = G_AIE_SEXT_EXTRACT_VECTOR_ELT %50(<16 x s32>), %3(s32) // pos | ||
| /// 0 %5:_(<16 x s32>) = G_AIE_BROADCAST_VECTOR %200(s32) | ||
| /// ... (chain of concat/unmerge/vector ops) | ||
| /// %2:_(s32) = G_AIE_SEXT_EXTRACT_VECTOR_ELT %result(<16 x s32>), %3(s32) // | ||
| /// pos 0 | ||
| /// | ||
| /// Transforms to: | ||
| /// %200:_(s32) = G_AIE_SEXT_EXTRACT_VECTOR_ELT %50(<16 x s32>), %3(s32) | ||
| /// %5:_(<16 x s32>) = COPY %50(<16 x s32>) // Copy source vector instead of | ||
| /// broadcast | ||
| /// ... (chain of operations) | ||
| /// %2:_(s32) = G_AIE_SEXT_EXTRACT_VECTOR_ELT %result(<16 x s32>), %3(s32) | ||
| bool llvm::matchBroadcastExtractToCopy(MachineInstr &MI, | ||
| MachineRegisterInfo &MRI, | ||
| const AIEBaseInstrInfo &TII, | ||
| BuildFnTy &MatchInfo) { | ||
| assert(MI.getOpcode() == TII.getGenericBroadcastVectorOpcode() && | ||
| "Expected G_AIE_BROADCAST_VECTOR"); | ||
|
|
||
| // 1. Verify broadcast source is extract from position 0 | ||
| const Register BroadcastSrcReg = MI.getOperand(1).getReg(); | ||
| const MachineInstr *ExtractMI = MRI.getVRegDef(BroadcastSrcReg); | ||
|
|
||
| if (!ExtractMI || !isGenericExtractOpcode(ExtractMI->getOpcode(), TII)) | ||
| return false; | ||
|
|
||
| // Verify extraction is from position 0 | ||
| const Register ExtractIdxReg = ExtractMI->getOperand(2).getReg(); | ||
| auto ExtractIdx = getIConstantVRegValWithLookThrough(ExtractIdxReg, MRI); | ||
| if (!ExtractIdx || ExtractIdx->Value.getZExtValue() != 0) | ||
| return false; | ||
|
|
||
| // Get the source vector that was extracted from | ||
| const Register ExtractSrcVecReg = ExtractMI->getOperand(1).getReg(); | ||
| const LLT ExtractSrcVecTy = MRI.getType(ExtractSrcVecReg); | ||
| const LLT BroadcastDstTy = MRI.getType(MI.getOperand(0).getReg()); | ||
|
|
||
| // Types must match exactly | ||
| if (ExtractSrcVecTy != BroadcastDstTy) | ||
| return false; | ||
|
|
||
| // 2. Verify all uses through the chain only extract position 0 | ||
| // using the helper function with single-use checks | ||
| const Register BroadcastDstReg = MI.getOperand(0).getReg(); | ||
| if (!verifyBroadcastUsesOnlyExtractZero(BroadcastDstReg, MRI, TII)) | ||
| return false; | ||
|
|
||
| MatchInfo = [ExtractSrcVecReg, BroadcastDstReg](MachineIRBuilder &B) { | ||
| B.buildCopy(BroadcastDstReg, ExtractSrcVecReg); | ||
| }; | ||
|
|
||
| return true; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
whitespace change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I removed one unnecessary space after...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you figured out how to run clang-format on .td files?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, sorry! I just saw this space and deleted it....