
Commit a858ab0

Author: Martien de Jong (committed)
[AIE] Supply vector implementation for FMUL widened from S16
This is a combine pattern that pushes S32 FNEGs through FPEXT. These FNEGs prevent standard instruction combines, and appear in some kernels. We make buildScalarAsVector switchable by a cl option to use either a broadcast or an insert into undef. The latter is easier for a follow-up combine, but is less clean because it may have NaN effects on flags. We wired the default to broadcast, since the combine covering the important case is already implemented. We also fix some cosmetics from earlier commits.
1 parent 784e0fa commit a858ab0
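For orientation, a schematic generic-MIR sketch of what the new combine does. Register names are invented, the vector opcodes are paraphrased in angle brackets, and the lane count follows the V32S16 type used in the patch; this is an illustration, not compiler output:

    ; Before: a 32-bit FMUL whose operands are widened from bf16, with an
    ; FNEG blocking other combines.
    %xw:_(s32) = G_FPEXT %x:_(s16)
    %nxw:_(s32) = G_FNEG %xw
    %yw:_(s32) = G_FPEXT %y:_(s16)
    %m:_(s32) = G_FMUL %nxw, %yw

    ; After: both bf16 scalars are placed in a vector (broadcast by default),
    ; one negating widening multiply (a MAC with zero accumulator, mode 60)
    ; produces a 32-bit-per-lane result, and lane 0 is extracted back.
    %vx:_(<32 x s16>) = <broadcast> %x
    %vy:_(<32 x s16>) = <broadcast> %y
    %acc:_(<32 x s32>) = llvm.aie2p.I512.I512.ACC1024.bf.negmul.conf %vx, %vy, 60
    %m:_(s32) = <extract lane 0, sign-extending> %acc

The single FNEG gives odd negation parity, which is why the negating multiply intrinsic is chosen here.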

File tree: 5 files changed (+296, -16 lines)

llvm/lib/Target/AIE/AIECombine.td

Lines changed: 7 additions & 0 deletions
@@ -195,6 +195,12 @@ def combine_narrow_zext_s20 : GICombineRule<
    [{ return matchNarrowZext(*${root}, MRI, Observer, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;

+def combine_widen_fmul : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_FMUL): $root,
+   [{ return matchWidenFMul(*${root}, MRI, Observer, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
+
 def concat_unmerge_matchdata : GIDefMatchData<"AIEConcatUnmergeCombineMatchData">;
 def combine_concat_unmerge_phis : GICombineRule <
   (defs root:$root, concat_unmerge_matchdata:$matchinfo),
@@ -299,6 +305,7 @@ def aie2p_additional_combines : GICombineGroup<[
   combine_vector_shuffle_to_extract_insert_elt,
   combine_vector_shuffle_concat_extracted_subvectors,
   combine_paired_extracts,
+  combine_widen_fmul,
   combine_vector_shuffle_to_extract_insert_elt_to_broadcast,
   combine_bitcast_unmerge_swap,
   combine_phi_bitcast_swap

llvm/lib/Target/AIE/AIECombinerHelper.cpp

Lines changed: 155 additions & 15 deletions
@@ -51,6 +51,11 @@ static cl::opt<bool> EnableGreedyAddressCombine(
     cl::desc("Enable greedy combines without checking for later uses of the "
              "base pointer"));

+static cl::opt<bool> PreferBroadcastOverInsert(
+    "aie-prefer-broadcast-over-insert", cl::Hidden, cl::init(true),
+    cl::desc("Use broadcast rather than insert-in-undefined to create "
+             "scalar values in vector"));
+
 cl::opt<bool> InlineMemCalls("aie-inline-mem-calls", cl::init(true), cl::Hidden,
                              cl::desc("Inline mem calls when profitable."));

@@ -68,6 +73,15 @@ cl::opt<bool> MemsetOptimizations(

 namespace {

+static constexpr const LLT S8 = LLT::scalar(8);
+static constexpr const LLT S16 = LLT::scalar(16);
+static constexpr const LLT S32 = LLT::scalar(32);
+static constexpr const LLT V32S16 = LLT::fixed_vector(32, 16);
+
+const llvm::AIEBaseInstrInfo &getAIETII(MachineIRBuilder &B) {
+  return static_cast<const AIEBaseInstrInfo &>(B.getTII());
+}
+
 bool isGenericExtractOpcode(unsigned Opc, const AIEBaseInstrInfo &TII) {
   // Check if it's either SEXT or ZEXT extract
   const unsigned ExtractSextOpc = TII.getGenericExtractVectorEltOpcode(true);
@@ -123,40 +137,108 @@ bool verifyBroadcastUsesOnlyExtractZero(Register Reg, MachineRegisterInfo &MRI,
                                             MRI, TII);
   // For unmerge, the useful operand should be the first one,
   // the other ones, they should be dead.
-  } else if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
+  }
+  if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
     unsigned OpCount = 0;
     for (auto &MO : UserMI->defs()) {
       Register DefReg = MO.getReg();
       if (OpCount == 0 && !MRI.hasOneUse(DefReg))
         return false;
-      else if (OpCount && !MRI.use_empty(DefReg))
+      if (OpCount && !MRI.use_empty(DefReg))
         return false;
       OpCount++;
     }
     return verifyBroadcastUsesOnlyExtractZero(UserMI->getOperand(0).getReg(),
                                               MRI, TII);
   // If we extract from zero, we succeed, otherwise we fail.
-  } else if (isGenericExtractOpcode(Opcode, TII)) {
+  }
+  if (isGenericExtractOpcode(Opcode, TII)) {
     const Register UseIdxReg = UserMI->getOperand(2).getReg();
     auto UseIdx = getIConstantVRegValWithLookThrough(UseIdxReg, MRI);
     return UseIdx && UseIdx->Value.getZExtValue() == 0;
   // If we bitcast, we may need other lanes.
-  } else if (Opcode == TargetOpcode::G_BITCAST) {
+  }
+  if (Opcode == TargetOpcode::G_BITCAST) {
+    return false;
+  }
+  if (mayMIShiftElements(UserMI)) {
     return false;
-  } else {
-    if (mayMIShiftElements(UserMI))
-      return false;
-    return verifyBroadcastUsesOnlyExtractZero(UserMI->getOperand(0).getReg(),
-                                              MRI, TII);
   }

-  return false;
+  return verifyBroadcastUsesOnlyExtractZero(UserMI->getOperand(0).getReg(), MRI,
+                                            TII);
 }

-} // namespace

-static unsigned getNumMaskUndefs(const ArrayRef<int> &Mask,
-                                 unsigned StartIndex) {
+Register buildInsertInUndef(MachineIRBuilder &B, Register Src, LLT VecTy) {
+  auto *MRI = B.getMRI();
+  if (MRI->getType(Src) != S32) {
+    Src = B.buildAnyExt(S32, Src).getReg(0);
+  }
+  const AIEBaseInstrInfo &TII = getAIETII(B);
+  const Register IdxReg = B.buildConstant(S32, 0).getReg(0);
+  const Register UndefVec = B.buildUndef(VecTy).getReg(0);
+  const unsigned InsertEltOpc = TII.getGenericInsertVectorEltOpcode();
+  Register Vector =
+      B.buildInstr(InsertEltOpc, {VecTy}, {UndefVec, Src, IdxReg}).getReg(0);
+
+  return Vector;
+}
+
+Register buildBroadcast(MachineIRBuilder &B, Register Src, LLT VecTy) {
+  auto *MRI = B.getMRI();
+  if (MRI->getType(Src) != S32) {
+    Src = B.buildAnyExt(S32, Src).getReg(0);
+  }
+  const AIEBaseInstrInfo &TII = getAIETII(B);
+  const unsigned InsertEltOpc = TII.getGenericBroadcastVectorOpcode();
+  Register Vector = B.buildInstr(InsertEltOpc, {VecTy}, {Src}).getReg(0);
+
+  return Vector;
+}
+
+Register buildScalarAsVector(MachineIRBuilder &B, Register Src, LLT VecTy) {
+  return PreferBroadcastOverInsert ? buildBroadcast(B, Src, VecTy)
+                                   : buildInsertInUndef(B, Src, VecTy);
+}
+
+// Build an element-wise multiplication into a vector of double width. These
+// are typical MAC operations with the incoming accumulator configured to be
+// zero. If Negate is true, uses the negating multiply intrinsic.
+Register buildWidenMulScalarAsVector(MachineIRBuilder &B, Register Lft,
+                                     Register Rgt, bool Negate) {
+  // Mode and intrinsic are target dependent.
+  auto *MRI = B.getMRI();
+  const int MulMode1x1 = 60;
+  LLT InTy = MRI->getType(Lft);
+  LLT OutTy = InTy.changeElementSize(InTy.getScalarSizeInBits() * 2);
+  const Register Acc = B.getMRI()->createGenericVirtualRegister(OutTy);
+  const Register Mode = B.buildConstant(S32, MulMode1x1).getReg(0);
+
+  // Choose the appropriate intrinsic based on whether we need negation.
+  // Both bf_mul_conf and bf_negmul_conf use the same mode parameter, which
+  // controls data types and multiplication configuration (see VecConf in
+  // AIE2PInstrPatterns.td). The intrinsic opcode controls the negation
+  // behavior via the dynMulNeg bit in the underlying instruction.
+  const Intrinsic::ID IntrID =
+      Negate ? Intrinsic::aie2p_I512_I512_ACC1024_bf_negmul_conf
+             : Intrinsic::aie2p_I512_I512_ACC1024_bf_mul_conf;
+
+  B.buildIntrinsic(IntrID, Acc, true, false)
+      .addUse(Lft)
+      .addUse(Rgt)
+      .addUse(Mode);
+  return Acc;
+}
+
+void buildFirstElement(MachineIRBuilder &B, Register DstReg, Register Vec) {
+  const AIEBaseInstrInfo &TII = getAIETII(B);
+  const Register Index = B.buildConstant(S32, 0).getReg(0);
+  B.buildInstr(TII.getGenericExtractVectorEltOpcode(/*SignExt*/ true), {DstReg},
+               {Vec, Index});
+}
+
+unsigned getNumMaskUndefs(const ArrayRef<int> &Mask, unsigned StartIndex) {
   unsigned Count = 0;
   for (unsigned I = StartIndex; I < Mask.size(); ++I) {
     if (Mask[I] == -1) {
@@ -166,6 +248,8 @@ static unsigned getNumMaskUndefs(const ArrayRef<int> &Mask,
   return Count;
 }

+} // namespace
+
 bool MaskMatch::isValidMask(const ArrayRef<int> Mask) const {
   for (unsigned Idx = 0; Idx < Mask.size(); ++Idx) {
     if (Mask[Idx] == -1)
@@ -1144,8 +1228,6 @@ bool llvm::matchExtractVecEltAndExt(
   assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT &&
          "Expected a extract_vector_elt");
   Register DstReg = MI.getOperand(0).getReg();
-  const LLT S8 = LLT::scalar(8);
-  const LLT S16 = LLT::scalar(16);
   LLT SrcVecTy = MRI.getType(MI.getOperand(1).getReg());
   // Extracts from vectors <= 64-bits are lowered to bit-arithmetic in
   // legalization
@@ -3609,6 +3691,64 @@ bool llvm::matchNarrowZext(MachineInstr &MI, MachineRegisterInfo &MRI,
   return false;
 }

+namespace {
+// We match widenings from 16 bit, with possible negations on top.
+// Negations commute with conversions and multiplications. We keep track of the
+// total number of negations modulo two.
+class ExtendOperand {
+public:
+  Register Source{};
+  bool Negate = false;
+  ExtendOperand operator-() { return {Source, !Negate}; }
+  operator bool() { return Source; }
+};
+
+ExtendOperand matchExtend(Register SrcReg, MachineRegisterInfo &MRI) {
+  const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
+  if (SrcMI->getOpcode() == TargetOpcode::G_FPEXT) {
+    const Register HalfOp = SrcMI->getOperand(1).getReg();
+    if (MRI.getType(HalfOp) != S16) {
+      return {};
+    }
+    return {HalfOp, false};
+  }
+  if (SrcMI->getOpcode() == TargetOpcode::G_FNEG) {
+    return -matchExtend(SrcMI->getOperand(1).getReg(), MRI);
+  }
+  return {};
+}
+} // namespace
+
+bool llvm::matchWidenFMul(MachineInstr &FMul, MachineRegisterInfo &MRI,
+                          GISelChangeObserver &Observer, BuildFnTy &MatchInfo) {
+  if (!FMul.getMF()->getTarget().getTargetTriple().isAIE2P()) {
+    return false;
+  }
+
+  ExtendOperand Lft = matchExtend(FMul.getOperand(1).getReg(), MRI);
+  if (!Lft) {
+    return false;
+  }
+  ExtendOperand Rgt = matchExtend(FMul.getOperand(2).getReg(), MRI);
+  if (!Rgt) {
+    return false;
+  }
+
+  const Register DstReg = FMul.getOperand(0).getReg();
+  const bool Negate = Lft.Negate ^ Rgt.Negate;
+
+  // We build extract(mul(tovector(Lft), tovector(Rgt)), 0)
+  MatchInfo = [=](MachineIRBuilder &B) {
+    const LLT VecTy = V32S16;
+    const Register VLhs = buildScalarAsVector(B, Lft.Source, VecTy);
+    const Register VRhs = buildScalarAsVector(B, Rgt.Source, VecTy);
+    const Register Acc = buildWidenMulScalarAsVector(B, VLhs, VRhs, Negate);
+    buildFirstElement(B, DstReg, Acc);
+  };
+
+  return true;
+}
+
 // Fold G_TRUNC (G_[ANY|S|Z]EXT x) -> X or (G_[ANY|S|Z]EXT x) or (G_TRUNC x).
 bool llvm::matchCombineExtAndTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
                                    BuildFnTy &MatchInfo) {
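The -aie-prefer-broadcast-over-insert flag picks between the two buildScalarAsVector flavors above. A minimal sketch of the difference in generic MIR, with invented register names and paraphrased AIE opcodes in angle brackets:

    ; Default (broadcast, flag = true): every lane holds the scalar, so no
    ; lane is undefined.
    %s32:_(s32) = G_ANYEXT %s:_(s16)
    %v:_(<32 x s16>) = <broadcast> %s32

    ; Alternative (insert in undef, flag = false): only lane 0 is defined.
    ; This is easier for a follow-up combine to see through, but the undef
    ; lanes may hold NaNs, which can affect the flags of the vector multiply.
    %u:_(<32 x s16>) = G_IMPLICIT_DEF
    %c0:_(s32) = G_CONSTANT i32 0
    %v:_(<32 x s16>) = <insert_vector_elt> %u, %s32, %c0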
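Since matchExtend recurses through G_FNEG and flips ExtendOperand::Negate on each step, matchWidenFMul only needs the parity Lft.Negate ^ Rgt.Negate to choose between bf_mul_conf and bf_negmul_conf. For example, two negations cancel (schematic, invented names):

    ; fneg(fpext %x) * fneg(fpext %y): both operands report Negate = true,
    ; parity is even, so the plain bf_mul_conf intrinsic is selected.
    %a:_(s32) = G_FNEG %xw
    %b:_(s32) = G_FNEG %yw
    %m:_(s32) = G_FMUL %a, %b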

llvm/lib/Target/AIE/AIECombinerHelper.h

Lines changed: 3 additions & 0 deletions
@@ -272,6 +272,9 @@ bool matchNarrowTruncConstant(MachineInstr &MI, MachineRegisterInfo &MRI,
 bool matchNarrowZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                      GISelChangeObserver &Observer, BuildFnTy &MatchInfo);

+bool matchWidenFMul(MachineInstr &MI, MachineRegisterInfo &MRI,
+                    GISelChangeObserver &Observer, BuildFnTy &MatchInfo);
+
 bool matchCombineExtAndTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
                              BuildFnTy &MatchInfo);

llvm/lib/Target/AIE/AIELegalizerHelper.cpp

Lines changed: 1 addition & 1 deletion
@@ -1405,7 +1405,7 @@ bool AIELegalizerHelper::legalizeG_FMUL(LegalizerHelper &Helper,
   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

   const Register DstReg = MI.getOperand(0).getReg();
-  assert(MRI.getType(DstReg) == LLT::scalar(16) &&
+  assert(MRI.getType(DstReg) == S16 &&
          "Expected bfloat16 type in custom legalization.");

   Register SrcLHS = MI.getOperand(1).getReg();