Skip to content

[mlir][spirv] Add 8-bit float type emulation #148811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
];
}

Expand Down Expand Up @@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
];
}

Expand Down Expand Up @@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
];
}

Expand Down Expand Up @@ -1167,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};

/// Whether to emulate unsupported floats with integer types of same bit
/// width.
bool emulateUnsupportedFloatTypes{true};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure we want this on by default? I think this can be a footgun when users inadvertently use unsupported fp types and get a dialect conversion error over integer types down the line...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather this was opt-in unless we have good justification for keeping this opt-out

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the theory is that you'll generally want to do software emulation of the small floats earlier?

Copy link
Member

@kuhar kuhar Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned about cases when you ended up with unsupported types by accident -- this often comes up with all the variants of fp8 and smaller fp types that are present in the input MLIR at the level of linalg/arith. This is much easier to diagnose when you can see the original type. IMO, dialect conversion should error out by default, unless someone opts into these types being handled in some other way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LLVM precedent is to just do the f8* => i8, though?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm ir or convert-to-llvm? llvm has the same issue as spirv here that its type system has fewer primitive types than mlir, so it's on mlir to figure out how to handle unsupported types

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convert-to-llvm, which, IIRC, doesn't even provide this opt-out mechanism, it' just defines the FP8 types to be i8

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, so we can leave this on by default to match llvm conversion if this is the case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @krzysz00 . Yes, I was following the llvm precedence.


/// How sub-byte values are storaged in memory.
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};

Expand Down
38 changes: 34 additions & 4 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}

// Get in IntegerAttr from FloatAttr while preserving the bits.
// Useful for converting float constants to integer constants while preserving
// the bits.
static IntegerAttr
getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
ConversionPatternRewriter &rewriter) {
APFloat floatVal = floatAttr.getValue();
APInt intVal = floatVal.bitcastToAPInt();
return rewriter.getIntegerAttr(dstType, intVal);
}

/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
Expand Down Expand Up @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
FloatAttr dstAttr =
convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
Attribute dstAttr = nullptr;
// Handle 8-bit float conversion to 8-bit integer.
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
srcElemType.getIntOrFloatBitWidth() == 8 &&
isa<IntegerType>(dstElemType)) {
dstAttr =
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
} else {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
rewriter);
}
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
Expand Down Expand Up @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
auto dstAttr = srcAttr;
Attribute dstAttr = srcAttr;

// Floating-point types not supported in the target environment are all
// converted to float type.
if (srcType != dstType) {
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
dstType.getIntOrFloatBitWidth() == 8) {
// If the source is an 8-bit float, convert it to a 8-bit integer.
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
if (!dstAttr)
return failure();
} else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
Expand Down Expand Up @@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

// Use UnrealizedConversionCast as the bridge so that we don't need to pull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

// TODO: We should also take care of block argument type conversion.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

RewritePatternSet patterns(context);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

RewritePatternSet patterns(context);
Expand Down
54 changes: 54 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}

// Handle 8-bit floats.
if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
if (bitWidth == 8)
return bitWidth / 8;
return std::nullopt;
}

if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
Expand Down Expand Up @@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}

/// Converts 8-bit float types to integer types with the same bit width.
/// Returns a nullptr for unsupported 8-bit float types.
static Type convert8BitFloatType(const SPIRVConversionOptions &options,
FloatType type) {
if (!options.emulateUnsupportedFloatTypes)
return nullptr;
// F8 types are converted to integer types with the same bit width.
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
Float8E8M0FNUType>(type))
return IntegerType::get(type.getContext(), type.getWidth());
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
return nullptr;
}

/// Returns a type with the same shape but with any 8-bit float element type
/// converted to the same bit width integer type. This is a noop when the
/// element type is not the 8-bit float type or emulation flag is set to false.
static ShapedType
convertShaped8BitFloatType(ShapedType type,
const SPIRVConversionOptions &options) {
if (!options.emulateUnsupportedFloatTypes)
return type;
Type srcElementType = type.getElementType();
Type convertedElementType = nullptr;
// F8 types are converted to integer types with the same bit width.
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
Float8E8M0FNUType>(srcElementType))
convertedElementType = IntegerType::get(
type.getContext(), srcElementType.getIntOrFloatBitWidth());

if (!convertedElementType)
return type;

return type.clone(convertedElementType);
}

/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
Expand All @@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer
Expand Down Expand Up @@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}

type = cast<TensorType>(convertIndexElementType(type, options));
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
Expand Down Expand Up @@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
// Hnadle 8 bit float types.
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
Expand Down Expand Up @@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
if (floatType.getWidth() == 8)
return convert8BitFloatType(this->options, floatType);
return Type();
});

Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,23 @@ func.func @constant() {
return
}

// CHECK-LABEL: @constant_8bit_float
func.func @constant_8bit_float() {
// CHECK: spirv.Constant 56 : i8
%cst = arith.constant 1.0 : f8E4M3
// CHECK: spirv.Constant 56 : i8
%cst_i8 = arith.bitcast %cst : f8E4M3 to i8
// CHECK: spirv.Constant dense<56> : vector<4xi8>
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
// CHECK: spirv.Constant dense<56> : vector<4xi8>
%cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
%cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
return
}

// CHECK-LABEL: @constant_16bit
func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16
Expand Down
54 changes: 54 additions & 0 deletions mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT

//===----------------------------------------------------------------------===//
// Integer types
Expand Down Expand Up @@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }

} // end module


// -----

// Check that 8-bit float types are emulated as i8.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
} {

// CHECK: spirv.func @float8_to_integer8
// CHECK-SAME: (%arg0: i8
// CHECK-SAME: %arg1: i8
// CHECK-SAME: %arg2: i8
// CHECK-SAME: %arg3: i8
// CHECK-SAME: %arg4: i8
// CHECK-SAME: %arg5: i8
// CHECK-SAME: %arg6: i8
// CHECK-SAME: %arg7: i8
// CHECK-SAME: %arg8: vector<4xi8>
// CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
// CHECK-SAME: %arg10: !spirv.array<4 x i8>
// UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
// UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
// UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
// UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
// UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
// UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
// UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
// UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
// UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
// UNSUPPORTED_FLOAT-SAME: ) {

func.func @float8_to_integer8(
%arg0: f8E5M2, // CHECK-NOT: f8E5M2
%arg1: f8E4M3, // CHECK-NOT: f8E4M3
%arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
%arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
%arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
%arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
%arg6: f8E3M4, // CHECK-NOT: f8E3M4
%arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
%arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
%arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
%arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
) {
// CHECK: spirv.Return
return
}
}