@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
499
499
// / and LLVM AMDGPU intrinsics convention.
500
500
// /
501
501
// / Specifically:
502
- // / 1. If the element type is bfloat16, bitcast it to i16.
502
+ // / 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
503
+ // / allows bf16. Newer MFMAs support bf16 types on operand, check
504
+ // / IntrinsicsAMDGPU.td file for reference.
503
505
// / 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
504
506
// / instead, which is what the f8f6f4 intrinsics use.
505
507
// / 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
509
511
// / therefore 8-bit and smaller floats are represented as their corresponding
510
512
// / `iN` integers.
511
513
static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
512
- Location loc, Value input) {
514
+ Location loc, Value input,
515
+ bool allowBf16 = true ) {
513
516
Type inputType = input.getType ();
514
517
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
515
- if (vectorType.getElementType ().isBF16 ())
518
+ if (vectorType.getElementType ().isBF16 () && !allowBf16 )
516
519
return rewriter.create <LLVM::BitcastOp>(
517
520
loc, vectorType.clone (rewriter.getI16Type ()), input);
518
521
if (vectorType.getElementType ().isInteger (8 ) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
958
961
959
962
StringRef intrinsicName =
960
963
isScaled ? std::get<0 >(*maybeScaledIntrinsic) : *maybeIntrinsic;
964
+ // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+
965
+ // allows bf16 as the input. For reference check IntrinsicsAMDGPU.td file.
966
+ bool allowBf16 = [&]() {
967
+ if (chipset < kGfx950 )
968
+ return false ;
969
+ if (isScaled)
970
+ return true ;
971
+ return intrinsicName.contains (" 16x16x32.bf16" ) ||
972
+ intrinsicName.contains (" 32x32x16.bf16" );
973
+ }();
961
974
OperationState loweredOp (loc, intrinsicName);
962
975
loweredOp.addTypes (intrinsicOutType);
963
- loweredOp.addOperands (
964
- {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
965
- convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
966
- adaptor.getDestC ()});
976
+ loweredOp.addOperands ({convertMFMAVectorOperand (
977
+ rewriter, loc, adaptor.getSourceA (), allowBf16),
978
+ convertMFMAVectorOperand (
979
+ rewriter, loc, adaptor.getSourceB (), allowBf16),
980
+ adaptor.getDestC ()});
967
981
if (isScaled) {
968
982
Value zero = createI32Constant (rewriter, loc, 0 );
969
983
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
0 commit comments