@@ -16316,7 +16316,12 @@ namespace {
16316
16316
// apply a combine.
16317
16317
struct CombineResult;
16318
16318
16319
- enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
16319
/// Extension kinds distinguished by the folding helpers below. Each enumerator
/// occupies its own bit so that a set of kinds can be tested with `&` against
/// a mask. BF16Ext maps to the same extension opcode as FPExt
/// (RISCVISD::FP_EXTEND_VL) but selects bf16 as the narrow element type.
+ enum ExtKind : uint8_t {
16320
+ ZExt = 1 << 0,
16321
+ SExt = 1 << 1,
16322
+ FPExt = 1 << 2,
16323
+ BF16Ext = 1 << 3
16324
+ };
16320
16325
/// Helper class for folding sign/zero extensions.
16321
16326
/// In particular, this class is used for the following combines:
16322
16327
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16351,8 +16356,10 @@ struct NodeExtensionHelper {
16351
16356
/// instance, a splat constant (e.g., 3), would support being both sign and
16352
16357
/// zero extended.
16353
16358
bool SupportsSExt;
16354
- /// Records if this operand is like being floating-Point extended.
16359
+ /// Records if this operand is like being floating point extended.
16355
16360
bool SupportsFPExt;
16361
+ /// Records if this operand is extended from bf16.
16362
+ bool SupportsBF16Ext;
16356
16363
/// This boolean captures whether we care if this operand would still be
16357
16364
/// around after the folding happens.
16358
16365
bool EnforceOneUse;
@@ -16388,6 +16395,7 @@ struct NodeExtensionHelper {
16388
16395
case ExtKind::ZExt:
16389
16396
return RISCVISD::VZEXT_VL;
16390
16397
case ExtKind::FPExt:
16398
+ case ExtKind::BF16Ext:
16391
16399
return RISCVISD::FP_EXTEND_VL;
16392
16400
}
16393
16401
llvm_unreachable("Unknown ExtKind enum");
@@ -16409,13 +16417,6 @@ struct NodeExtensionHelper {
16409
16417
if (Source.getValueType() == NarrowVT)
16410
16418
return Source;
16411
16419
16412
- // vfmadd_vl -> vfwmadd_vl can take bf16 operands
16413
- if (Source.getValueType().getVectorElementType() == MVT::bf16) {
16414
- assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
16415
- Root->getOpcode() == RISCVISD::VFMADD_VL);
16416
- return Source;
16417
- }
16418
-
16419
16420
unsigned ExtOpc = getExtOpc(*SupportsExt);
16420
16421
16421
16422
// If we need an extension, we should be changing the type.
@@ -16458,7 +16459,8 @@ struct NodeExtensionHelper {
16458
16459
// Determine the narrow size.
16459
16460
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
16460
16461
16461
- MVT EltVT = SupportsExt == ExtKind::FPExt
16462
+ MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
16463
+ : SupportsExt == ExtKind::FPExt
16462
16464
? MVT::getFloatingPointVT(NarrowSize)
16463
16465
: MVT::getIntegerVT(NarrowSize);
16464
16466
@@ -16635,17 +16637,13 @@ struct NodeExtensionHelper {
16635
16637
EnforceOneUse = false;
16636
16638
}
16637
16639
16638
- bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
16639
- const RISCVSubtarget &Subtarget) {
16640
- // Any f16 extension will need zvfh
16641
- if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
16642
- return false;
16643
- // The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
16644
- // zvfbfwma
16645
- if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
16646
- Root->getOpcode() != RISCVISD::VFMADD_VL))
16647
- return false;
16648
- return true;
16640
/// Returns true when a floating-point extend from \p NarrowEltVT is accepted:
/// f32 unconditionally, f16 only if Subtarget.hasVInstructionsF16() holds.
/// Every other element type (bf16 included) yields false.
+ bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16641
+ return (NarrowEltVT == MVT::f32 ||
16642
+ (NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16()));
16643
+ }
16644
+
16645
/// Returns true only when \p NarrowEltVT is bf16 and
/// Subtarget.hasStdExtZvfbfwma() holds; false for every other element type.
+ bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16646
+ return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
16649
16647
}
16650
16648
16651
16649
/// Helper method to set the various fields of this struct based on the
@@ -16655,6 +16653,7 @@ struct NodeExtensionHelper {
16655
16653
SupportsZExt = false;
16656
16654
SupportsSExt = false;
16657
16655
SupportsFPExt = false;
16656
+ SupportsBF16Ext = false;
16658
16657
EnforceOneUse = true;
16659
16658
unsigned Opc = OrigOperand.getOpcode();
16660
16659
// For the nodes we handle below, we end up using their inputs directly: see
@@ -16686,9 +16685,11 @@ struct NodeExtensionHelper {
16686
16685
case RISCVISD::FP_EXTEND_VL: {
16687
16686
MVT NarrowEltVT =
16688
16687
OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
16689
- if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
16690
- break;
16691
- SupportsFPExt = true;
16688
+ if (isSupportedFPExtend(NarrowEltVT, Subtarget))
16689
+ SupportsFPExt = true;
16690
+ if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
16691
+ SupportsBF16Ext = true;
16692
+
16692
16693
break;
16693
16694
}
16694
16695
case ISD::SPLAT_VECTOR:
@@ -16705,16 +16706,16 @@ struct NodeExtensionHelper {
16705
16706
if (Op.getOpcode() != ISD::FP_EXTEND)
16706
16707
break;
16707
16708
16708
- if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
16709
- Subtarget))
16710
- break;
16711
-
16712
16709
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
16713
16710
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
16714
16711
if (NarrowSize != ScalarBits)
16715
16712
break;
16716
16713
16717
- SupportsFPExt = true;
16714
+ if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
16715
+ SupportsFPExt = true;
16716
+ if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
16717
+ Subtarget))
16718
+ SupportsBF16Ext = true;
16718
16719
break;
16719
16720
}
16720
16721
default:
@@ -16947,6 +16948,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
16947
16948
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
16948
16949
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
16949
16950
/*RHSExt=*/{ExtKind::FPExt});
16951
+ if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
16952
+ RHS.SupportsBF16Ext)
16953
+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
16954
+ Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
16955
+ /*RHSExt=*/{ExtKind::BF16Ext});
16950
16956
return std::nullopt;
16951
16957
}
16952
16958
@@ -17029,6 +17035,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17029
17035
Subtarget);
17030
17036
}
17031
17037
17038
+ /// Check if \p Root follows the pattern Root(bf16ext(LHS), bf16ext(RHS)).
17039
+ /// Forwards to canFoldToVWWithSameExtensionImpl, passing ExtKind::BF16Ext.
17040
+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17041
+ /// can be used to apply the pattern.
17042
+ static std::optional<CombineResult>
17043
+ canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17044
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17045
+ const RISCVSubtarget &Subtarget) {
17046
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17047
+ Subtarget);
17048
+ }
17049
+
17032
17050
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
17033
17051
///
17034
17052
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17068,6 +17086,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
17068
17086
case RISCVISD::VFNMADD_VL:
17069
17087
case RISCVISD::VFNMSUB_VL:
17070
17088
Strategies.push_back(canFoldToVWWithSameExtension);
17089
+ if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17090
+ Strategies.push_back(canFoldToVWWithBF16EXT);
17071
17091
break;
17072
17092
case ISD::MUL:
17073
17093
case RISCVISD::MUL_VL:
0 commit comments