Commit 836201f

Allow bf16 operands on new MFMAs (#144925)
The new gfx950 MFMAs allow bf16 operands. https://github.com/llvm/llvm-project/blob/c0cc81cdc03c97473ba771bbc3a2330bd22396bc/llvm/include/llvm/IR/IntrinsicsAMDGPU.td#L3434 When running `amdgpu-to-rocdl`, the current logic always converts bf16 to i16, which fails to compile for the newer bf16 MFMAs, e.g. `v_mfma_f32_16x16x32bf16`: the backend expects bf16-typed operands for those intrinsics. This patch fixes that.

CC: @krzysz00 @dhernandez0 @giuseros @antiagainst @kuhar
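To make the effect concrete, here is a minimal sketch of the fixed lowering (the SSA names %a, %acc, %z and the result bindings are illustrative assumptions; the shapes mirror the updated test below). On gfx950, a bf16 `amdgpu.mfma` keeps its operands as bf16 vectors instead of bitcasting them to i16:

    // Illustrative gfx950 input (names assumed):
    %d = amdgpu.mfma %a * %a + %acc { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
    // Sketch of the lowered form after `amdgpu-to-rocdl` (%z is an i32 zero constant);
    // note the vector<8xbf16> operands, which previously would have been vector<8xi16>:
    %r = rocdl.mfma.f32.16x16x32.bf16 %a, %a, %acc, %z, %z, %z : (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>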

File tree: 2 files changed (+23 −9 lines)

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 21 additions & 7 deletions
@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
-/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
+///    intrinsic allows bf16. Newer MFMAs support bf16 operands; see
+///    IntrinsicsAMDGPU.td for reference.
 /// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
 ///    instead, which is what the f8f6f4 intrinsics use.
 /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                      Location loc, Value input) {
+                                      Location loc, Value input,
+                                      bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
-    if (vectorType.getElementType().isBF16())
+    if (vectorType.getElementType().isBF16() && !allowBf16)
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
     if (vectorType.getElementType().isInteger(8) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {

     StringRef intrinsicName =
         isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+    // Determine whether we can use bf16 in the intrinsic. Newer MFMAs on
+    // gfx950+ allow bf16 as the input; see IntrinsicsAMDGPU.td for reference.
+    bool allowBf16 = [&]() {
+      if (chipset < kGfx950)
+        return false;
+      if (isScaled)
+        return true;
+      return intrinsicName.contains("16x16x32.bf16") ||
+             intrinsicName.contains("32x32x16.bf16");
+    }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands(
-        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
-         adaptor.getDestC()});
+    loweredOp.addOperands({convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceA(), allowBf16),
+                           convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceB(), allowBf16),
+                           adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
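For contrast, a minimal sketch of the path the `chipset < kGfx950` guard preserves (the chipset, intrinsic name, and SSA names here are assumptions for illustration, not taken from this patch). On a pre-gfx950 target, bf16 operands are still bitcast to i16 before the intrinsic:

    // Assumed pre-gfx950 lowering (e.g. chipset=gfx942): allowBf16 is false,
    // so the bf16 vectors are bitcast to i16 vectors first.
    %castA = llvm.bitcast %a : vector<4xbf16> to vector<4xi16>
    %r = rocdl.mfma.f32.16x16x16bf16.1k %castA, %castA, %acc, %z, %z, %z : (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>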

mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir

Lines changed: 2 additions & 2 deletions
@@ -11,9 +11,9 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
   amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
-  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
-  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
   amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
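As a usage note, the updated checks would typically be exercised along these lines (a sketch assuming the test's usual RUN-line convention; the authoritative flags live in the file's own RUN line):

    mlir-opt mfma-gfx950.mlir --convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck mfma-gfx950.mlir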
