Skip to content

Commit 04e2e58

Browse files
authored
[RISCV] Treat bf16->f32 as separate ExtKind in combineOp_VLToVWOp_VL. (#144653)
This allows us to better track the narrow type we need, and fixes miscompiles when f16->f32 and bf16->f32 extends are mixed. Fixes #144651.
1 parent adc6228 commit 04e2e58

File tree

2 files changed

+103
-33
lines changed

2 files changed

+103
-33
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16316,7 +16316,12 @@ namespace {
1631616316
// apply a combine.
1631716317
struct CombineResult;
1631816318

16319-
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
16319+
enum ExtKind : uint8_t {
16320+
ZExt = 1 << 0,
16321+
SExt = 1 << 1,
16322+
FPExt = 1 << 2,
16323+
BF16Ext = 1 << 3
16324+
};
1632016325
/// Helper class for folding sign/zero extensions.
1632116326
/// In particular, this class is used for the following combines:
1632216327
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
@@ -16351,8 +16356,10 @@ struct NodeExtensionHelper {
1635116356
/// instance, a splat constant (e.g., 3), would support being both sign and
1635216357
/// zero extended.
1635316358
bool SupportsSExt;
16354-
/// Records if this operand is like being floating-Point extended.
16359+
/// Records if this operand is like being floating point extended.
1635516360
bool SupportsFPExt;
16361+
/// Records if this operand is extended from bf16.
16362+
bool SupportsBF16Ext;
1635616363
/// This boolean captures whether we care if this operand would still be
1635716364
/// around after the folding happens.
1635816365
bool EnforceOneUse;
@@ -16388,6 +16395,7 @@ struct NodeExtensionHelper {
1638816395
case ExtKind::ZExt:
1638916396
return RISCVISD::VZEXT_VL;
1639016397
case ExtKind::FPExt:
16398+
case ExtKind::BF16Ext:
1639116399
return RISCVISD::FP_EXTEND_VL;
1639216400
}
1639316401
llvm_unreachable("Unknown ExtKind enum");
@@ -16409,13 +16417,6 @@ struct NodeExtensionHelper {
1640916417
if (Source.getValueType() == NarrowVT)
1641016418
return Source;
1641116419

16412-
// vfmadd_vl -> vfwmadd_vl can take bf16 operands
16413-
if (Source.getValueType().getVectorElementType() == MVT::bf16) {
16414-
assert(Root->getSimpleValueType(0).getVectorElementType() == MVT::f32 &&
16415-
Root->getOpcode() == RISCVISD::VFMADD_VL);
16416-
return Source;
16417-
}
16418-
1641916420
unsigned ExtOpc = getExtOpc(*SupportsExt);
1642016421

1642116422
// If we need an extension, we should be changing the type.
@@ -16458,7 +16459,8 @@ struct NodeExtensionHelper {
1645816459
// Determine the narrow size.
1645916460
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1646016461

16461-
MVT EltVT = SupportsExt == ExtKind::FPExt
16462+
MVT EltVT = SupportsExt == ExtKind::BF16Ext ? MVT::bf16
16463+
: SupportsExt == ExtKind::FPExt
1646216464
? MVT::getFloatingPointVT(NarrowSize)
1646316465
: MVT::getIntegerVT(NarrowSize);
1646416466

@@ -16635,17 +16637,13 @@ struct NodeExtensionHelper {
1663516637
EnforceOneUse = false;
1663616638
}
1663716639

16638-
bool isSupportedFPExtend(SDNode *Root, MVT NarrowEltVT,
16639-
const RISCVSubtarget &Subtarget) {
16640-
// Any f16 extension will need zvfh
16641-
if (NarrowEltVT == MVT::f16 && !Subtarget.hasVInstructionsF16())
16642-
return false;
16643-
// The only bf16 extension we can do is vfmadd_vl -> vfwmadd_vl with
16644-
// zvfbfwma
16645-
if (NarrowEltVT == MVT::bf16 && (!Subtarget.hasStdExtZvfbfwma() ||
16646-
Root->getOpcode() != RISCVISD::VFMADD_VL))
16647-
return false;
16648-
return true;
16640+
bool isSupportedFPExtend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16641+
return (NarrowEltVT == MVT::f32 ||
16642+
(NarrowEltVT == MVT::f16 && Subtarget.hasVInstructionsF16()));
16643+
}
16644+
16645+
bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
16646+
return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
1664916647
}
1665016648

1665116649
/// Helper method to set the various fields of this struct based on the
@@ -16655,6 +16653,7 @@ struct NodeExtensionHelper {
1665516653
SupportsZExt = false;
1665616654
SupportsSExt = false;
1665716655
SupportsFPExt = false;
16656+
SupportsBF16Ext = false;
1665816657
EnforceOneUse = true;
1665916658
unsigned Opc = OrigOperand.getOpcode();
1666016659
// For the nodes we handle below, we end up using their inputs directly: see
@@ -16686,9 +16685,11 @@ struct NodeExtensionHelper {
1668616685
case RISCVISD::FP_EXTEND_VL: {
1668716686
MVT NarrowEltVT =
1668816687
OrigOperand.getOperand(0).getSimpleValueType().getVectorElementType();
16689-
if (!isSupportedFPExtend(Root, NarrowEltVT, Subtarget))
16690-
break;
16691-
SupportsFPExt = true;
16688+
if (isSupportedFPExtend(NarrowEltVT, Subtarget))
16689+
SupportsFPExt = true;
16690+
if (isSupportedBF16Extend(NarrowEltVT, Subtarget))
16691+
SupportsBF16Ext = true;
16692+
1669216693
break;
1669316694
}
1669416695
case ISD::SPLAT_VECTOR:
@@ -16705,16 +16706,16 @@ struct NodeExtensionHelper {
1670516706
if (Op.getOpcode() != ISD::FP_EXTEND)
1670616707
break;
1670716708

16708-
if (!isSupportedFPExtend(Root, Op.getOperand(0).getSimpleValueType(),
16709-
Subtarget))
16710-
break;
16711-
1671216709
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
1671316710
unsigned ScalarBits = Op.getOperand(0).getValueSizeInBits();
1671416711
if (NarrowSize != ScalarBits)
1671516712
break;
1671616713

16717-
SupportsFPExt = true;
16714+
if (isSupportedFPExtend(Op.getOperand(0).getSimpleValueType(), Subtarget))
16715+
SupportsFPExt = true;
16716+
if (isSupportedBF16Extend(Op.getOperand(0).getSimpleValueType(),
16717+
Subtarget))
16718+
SupportsBF16Ext = true;
1671816719
break;
1671916720
}
1672016721
default:
@@ -16947,6 +16948,11 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
1694716948
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
1694816949
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
1694916950
/*RHSExt=*/{ExtKind::FPExt});
16951+
if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
16952+
RHS.SupportsBF16Ext)
16953+
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
16954+
Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
16955+
/*RHSExt=*/{ExtKind::BF16Ext});
1695016956
return std::nullopt;
1695116957
}
1695216958

@@ -17029,6 +17035,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1702917035
Subtarget);
1703017036
}
1703117037

17038+
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17039+
///
17040+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17041+
/// can be used to apply the pattern.
17042+
static std::optional<CombineResult>
17043+
canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17044+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17045+
const RISCVSubtarget &Subtarget) {
17046+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17047+
Subtarget);
17048+
}
17049+
1703217050
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
1703317051
///
1703417052
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17068,6 +17086,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1706817086
case RISCVISD::VFNMADD_VL:
1706917087
case RISCVISD::VFNMSUB_VL:
1707017088
Strategies.push_back(canFoldToVWWithSameExtension);
17089+
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17090+
Strategies.push_back(canFoldToVWWithBF16EXT);
1707117091
break;
1707217092
case ISD::MUL:
1707317093
case RISCVISD::MUL_VL:

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmaccbf16.ll

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
3-
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
4-
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
5-
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
2+
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
3+
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfwma -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFWMA
4+
; RUN: llc < %s -mtriple=riscv32 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
5+
; RUN: llc < %s -mtriple=riscv64 -mattr=+v,+zvfh,+zvfbfmin -verify-machineinstrs | FileCheck %s --check-prefix=ZVFBFMIN
66

77
define <1 x float> @vfwmaccbf16_vv_v1f32(<1 x float> %a, <1 x bfloat> %b, <1 x bfloat> %c) {
88
; ZVFBFWMA-LABEL: vfwmaccbf16_vv_v1f32:
@@ -295,3 +295,53 @@ define <32 x float> @vfwmaccbf32_vf_v32f32(<32 x float> %a, bfloat %b, <32 x bfl
295295
%res = call <32 x float> @llvm.fma.v32f32(<32 x float> %b.ext, <32 x float> %c.ext, <32 x float> %a)
296296
ret <32 x float> %res
297297
}
298+
299+
define <4 x float> @vfwmaccbf16_vf_v4f32_scalar_extend(<4 x float> %rd, bfloat %a, <4 x bfloat> %b) local_unnamed_addr #0 {
300+
; ZVFBFWMA-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
301+
; ZVFBFWMA: # %bb.0:
302+
; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
303+
; ZVFBFWMA-NEXT: vfwmaccbf16.vf v8, fa0, v9
304+
; ZVFBFWMA-NEXT: ret
305+
;
306+
; ZVFBFMIN-LABEL: vfwmaccbf16_vf_v4f32_scalar_extend:
307+
; ZVFBFMIN: # %bb.0:
308+
; ZVFBFMIN-NEXT: fmv.x.w a0, fa0
309+
; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
310+
; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v10, v9
311+
; ZVFBFMIN-NEXT: slli a0, a0, 16
312+
; ZVFBFMIN-NEXT: fmv.w.x fa5, a0
313+
; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
314+
; ZVFBFMIN-NEXT: vfmacc.vf v8, fa5, v10
315+
; ZVFBFMIN-NEXT: ret
316+
%b_ext = fpext <4 x bfloat> %b to <4 x float>
317+
%a_extend = fpext bfloat %a to float
318+
%a_insert = insertelement <4 x float> poison, float %a_extend, i64 0
319+
%a_shuffle = shufflevector <4 x float> %a_insert, <4 x float> poison, <4 x i32> zeroinitializer
320+
%fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_shuffle, <4 x float> %b_ext, <4 x float> %rd)
321+
ret <4 x float> %fma
322+
}
323+
324+
; Negative test with a mix of bfloat and half fpext.
325+
define <4 x float> @mix(<4 x float> %rd, <4 x half> %a, <4 x bfloat> %b) {
326+
; ZVFBFWMA-LABEL: mix:
327+
; ZVFBFWMA: # %bb.0:
328+
; ZVFBFWMA-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
329+
; ZVFBFWMA-NEXT: vfwcvt.f.f.v v11, v9
330+
; ZVFBFWMA-NEXT: vfwcvtbf16.f.f.v v9, v10
331+
; ZVFBFWMA-NEXT: vsetvli zero, zero, e32, m1, ta, ma
332+
; ZVFBFWMA-NEXT: vfmacc.vv v8, v11, v9
333+
; ZVFBFWMA-NEXT: ret
334+
;
335+
; ZVFBFMIN-LABEL: mix:
336+
; ZVFBFMIN: # %bb.0:
337+
; ZVFBFMIN-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
338+
; ZVFBFMIN-NEXT: vfwcvt.f.f.v v11, v9
339+
; ZVFBFMIN-NEXT: vfwcvtbf16.f.f.v v9, v10
340+
; ZVFBFMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
341+
; ZVFBFMIN-NEXT: vfmacc.vv v8, v11, v9
342+
; ZVFBFMIN-NEXT: ret
343+
%a_ext = fpext <4 x half> %a to <4 x float>
344+
%b_ext = fpext <4 x bfloat> %b to <4 x float>
345+
%fma = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %a_ext, <4 x float> %b_ext, <4 x float> %rd)
346+
ret <4 x float> %fma
347+
}

0 commit comments

Comments
 (0)