diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 074404add47f1..700563460f525 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
-/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
+/// intrinsic allows bf16 directly. Newer MFMAs (gfx950+) support bf16
+/// operands; see IntrinsicsAMDGPU.td for reference.
 /// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
 /// instead, which is what the f8f6f4 intrinsics use.
 /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
 static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                      Location loc, Value input) {
+                                      Location loc, Value input,
+                                      bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
-    if (vectorType.getElementType().isBF16())
+    if (vectorType.getElementType().isBF16() && !allowBf16)
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
     if (vectorType.getElementType().isInteger(8) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
 
     StringRef intrinsicName =
         isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+    // Determine whether we can use bf16 in the intrinsic. Newer MFMAs on
+    // gfx950+ allow bf16 inputs; see IntrinsicsAMDGPU.td for reference.
+    bool allowBf16 = [&]() {
+      if (chipset < kGfx950)
+        return false;
+      if (isScaled)
+        return true;
+      return intrinsicName.contains("16x16x32.bf16") ||
+             intrinsicName.contains("32x32x16.bf16");
+    }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands(
-        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
-         adaptor.getDestC()});
+    loweredOp.addOperands({convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceA(), allowBf16),
+                           convertMFMAVectorOperand(
+                               rewriter, loc, adaptor.getSourceB(), allowBf16),
+                           adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 52a5d39f668c6..39c31d5bf2fa3 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -11,9 +11,9 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
   amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
-  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
-  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
   amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
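
For illustration, below is a minimal standalone sketch of the updated bf16 behavior, modeled directly on the test hunk above. The function name @bf16_mfma_to_rocdl and the reduced two-op framing are hypothetical, and it assumes the same mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx950 RUN invocation used by the rest of mfma-gfx950.mlir; with this patch the bf16 operands reach the ROCDL intrinsic unchanged rather than being bitcast to vector<8xi16>.

// Hypothetical reduced test, assuming the chipset=gfx950 pass pipeline.
// CHECK-LABEL: func @bf16_mfma_to_rocdl
func.func @bf16_mfma_to_rocdl(%a : vector<8xbf16>,
                              %acc32 : vector<16xf32>, %acc16 : vector<4xf32>) {
  // 32x32x16 bf16 MFMA: operands stay vector<8xbf16> in the lowered intrinsic.
  // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
  amdgpu.mfma %a * %a + %acc32 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
  // 16x16x32 bf16 MFMA: same, no bitcast to i16 is emitted.
  // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
  amdgpu.mfma %a * %a + %acc16 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
  func.return
}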