diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml
index 92cdbb70e9858..93c201c3c6d60 100644
--- a/.github/workflows/linux_minimal_build.yml
+++ b/.github/workflows/linux_minimal_build.yml
@@ -325,7 +325,7 @@ jobs:
--build_wheel
--use_binskim_compliant_compile_flags
--disable_ml_ops
- --disable_types sparsetensor float8 optional
+ --disable_types sparsetensor float4 float8 optional
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
@@ -341,7 +341,7 @@ jobs:
--build_wheel
--use_binskim_compliant_compile_flags
--disable_ml_ops
- --disable_types sparsetensor float8 optional
+ --disable_types sparsetensor float4 float8 optional
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
@@ -358,7 +358,7 @@ jobs:
--build_wheel
--use_binskim_compliant_compile_flags
--disable_ml_ops
- --disable_types sparsetensor float8 optional
+ --disable_types sparsetensor float4 float8 optional
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
@@ -408,7 +408,7 @@ jobs:
--disable_ml_ops
--skip_tests
--enable_reduced_operator_type_support
- --disable_types sparsetensor optional float8
+ --disable_types sparsetensor optional float4 float8
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
@@ -427,7 +427,7 @@ jobs:
--disable_ml_ops
--skip_tests
--enable_reduced_operator_type_support
- --disable_types sparsetensor optional float8
+ --disable_types sparsetensor optional float4 float8
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
@@ -483,7 +483,7 @@ jobs:
--disable_ml_ops
--skip_tests
--enable_reduced_operator_type_support
- --disable_types sparsetensor optional float8
+ --disable_types sparsetensor optional float4 float8
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
@@ -502,7 +502,7 @@ jobs:
--disable_ml_ops
--skip_tests
--enable_reduced_operator_type_support
- --disable_types sparsetensor optional float8
+ --disable_types sparsetensor optional float4 float8
--include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index baf21745e0e40..7ac4729077734 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -155,6 +155,7 @@ option(onnxruntime_DISABLE_ML_OPS "Disable traditional ML ops" OFF)
option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OFF)
option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF)
option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF)
+option(onnxruntime_DISABLE_FLOAT4_TYPES "Disable float 4 types" OFF)
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF)
cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF)
@@ -1029,6 +1030,10 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT8_TYPES)
endif()
+ if (onnxruntime_DISABLE_FLOAT4_TYPES)
+ target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT4_TYPES)
+ endif()
+
if (onnxruntime_ENABLE_ATEN)
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
endif()
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 660c63d056335..a6d69198fadcc 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -622,10 +622,12 @@ Do not modify directly.*
|||14|**T** = tensor(double), tensor(float), tensor(float16)<br> **U** = tensor(double), tensor(float), tensor(float16)|
|||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Cast|*in* input:**T1**<br> *out* output:**T2**|19+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**<br> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Ceil|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|Clip|*in* input:**T**<br> *in* min:**T**<br> *in* max:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)|
diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h
index d8822b3e452d5..7f8e94305656e 100644
--- a/include/onnxruntime/core/framework/data_types.h
+++ b/include/onnxruntime/core/framework/data_types.h
@@ -16,6 +16,7 @@
#include "core/framework/float8.h"
#include "core/framework/float16.h"
#include "core/framework/int4.h"
+#include "core/framework/float4.h"
#include "core/graph/onnx_protobuf.h"
#include "core/framework/to_tensor_proto_element_type.h"
@@ -209,6 +210,7 @@ class DataTypeImpl {
static const std::vector<MLDataType>& AllTensorTypesIRv4();
static const std::vector<MLDataType>& AllTensorTypesIRv9();
static const std::vector<MLDataType>& AllTensorTypesIRv10();
+ static const std::vector<MLDataType>& AllTensorTypesIRv11();
static const std::vector<MLDataType>& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated
static const std::vector<MLDataType>& AllFixedSizeTensorTypesIRv4();
@@ -287,6 +289,10 @@ struct IsTensorContainedType : public IsAnyOf {
};
@@ -302,6 +308,10 @@ struct IsSparseTensorContainedType : public IsAnyOf {
};
@@ -921,7 +931,7 @@ class OpaqueType : public NonTensorType {
*
* \details This class contains an integer constant that can be
* used for input data type dispatching. This class also stores the number of subelements per size units.
- * Example: For int4, the size unit is 1 byte and the number of subelements is 2.
+ * Example: For float4/int4, the size unit is 1 byte and the number of subelements is 2.
*
*/
class PrimitiveDataTypeBase : public DataTypeImpl {
@@ -1101,6 +1111,7 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
// Registers a subbyte primitive.
// Examples:
// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
+// - Float4E2M1x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Float4E2M1x2, 2)
// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \
template <> \
diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h
index 4cc57ba4b5391..7c66b676117c9 100644
--- a/include/onnxruntime/core/framework/data_types_internal.h
+++ b/include/onnxruntime/core/framework/data_types_internal.h
@@ -35,7 +35,263 @@ namespace utils {
// Invoking DataTypeImpl::GetType() for switching on input types is discouraged and should be avoided.
// Every primitive type carries with it an integer constant that can be used for quick switching on types.
-#if !defined(DISABLE_FLOAT8_TYPES)
+#if !defined(DISABLE_FLOAT8_TYPES) && !defined(DISABLE_FLOAT4_TYPES)
+
+#define DispatchOnTensorType(tensor_type, function, ...) \
+ switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
+ function<float>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
+ function<bool>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
+ function<double>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
+ function<std::string>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
+ function<int8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
+ function<uint8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
+ function<int16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
+ function<uint16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
+ function<int32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
+ function<uint32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
+ function<int64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
+ function<uint64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
+ function<MLFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
+ function<BFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: \
+ function<Float8E4M3FN>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: \
+ function<Float8E4M3FNUZ>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: \
+ function<Float8E5M2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
+ function<Float8E5M2FNUZ>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \
+ function<Float4E2M1x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+ function<Int4x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+ function<UInt4x2>(__VA_ARGS__); \
+ break; \
+ default: \
+ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
+ }
+
+#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
+ switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
+ retval = function<float>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
+ retval = function<bool>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
+ retval = function<double>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
+ retval = function<std::string>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
+ retval = function<int8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
+ retval = function<uint8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
+ retval = function<uint16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
+ retval = function<int16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
+ retval = function<int32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
+ retval = function<uint32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
+ retval = function<int64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
+ retval = function<uint64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
+ retval = function<MLFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
+ retval = function<BFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: \
+ retval = function<Float8E4M3FN>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: \
+ retval = function<Float8E4M3FNUZ>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: \
+ retval = function<Float8E5M2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
+ retval = function<Float8E5M2FNUZ>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \
+ retval = function<Float4E2M1x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+ retval = function<Int4x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+ retval = function<UInt4x2>(__VA_ARGS__); \
+ break; \
+ default: \
+ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
+ }
+
+#elif !defined(DISABLE_FLOAT4_TYPES)
+
+#define DispatchOnTensorType(tensor_type, function, ...) \
+ switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
+ function<float>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
+ function<bool>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
+ function<double>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
+ function<std::string>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
+ function<int8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
+ function<uint8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
+ function<int16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
+ function<uint16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
+ function<int32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
+ function<uint32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
+ function<int64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
+ function<uint64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
+ function<MLFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
+ function<BFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \
+ function<Float4E2M1x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+ function<Int4x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+ function<UInt4x2>(__VA_ARGS__); \
+ break; \
+ default: \
+ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
+ }
+
+#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
+ switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
+ retval = function<float>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
+ retval = function<bool>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
+ retval = function<double>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
+ retval = function<std::string>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
+ retval = function<int8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
+ retval = function<uint8_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
+ retval = function<uint16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
+ retval = function<int16_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
+ retval = function<int32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
+ retval = function<uint32_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
+ retval = function<int64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
+ retval = function<uint64_t>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
+ retval = function<MLFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
+ retval = function<BFloat16>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \
+ retval = function<Float4E2M1x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+ retval = function<Int4x2>(__VA_ARGS__); \
+ break; \
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+ retval = function<UInt4x2>(__VA_ARGS__); \
+ break; \
+ default: \
+ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
+ }
+
+#elif !defined(DISABLE_FLOAT8_TYPES)
#define DispatchOnTensorType(tensor_type, function, ...) \
switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
diff --git a/include/onnxruntime/core/framework/float4.h b/include/onnxruntime/core/framework/float4.h
new file mode 100644
index 0000000000000..3662556f53398
--- /dev/null
+++ b/include/onnxruntime/core/framework/float4.h
@@ -0,0 +1,297 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// IMPORTANT NOTE: Users of this file MUST include "cuda.h" before including this header
+// if they would like to leverage the CUDA implementation for the conversion routines
+// in their HOST code (code compiled by MSVC/GCC).
+// This is because there is a check on CUDA_VERSION which is a macro defined in cuda.h.
+// We can't include cuda.h in this header unconditionally because this header is also
+// included in core framework files which are CUDA-agnostic.
+// Not including "cuda.h" in GCC/MSVC builds will fall back to the CPU conversion routines
+// implemented in this file.
+// For code compiled by NVCC which includes this header, this file will automatically
+// include cuda.h (based on the __CUDACC__ macro).
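+//
+// For example (illustrative only), host code that wants the CUDA-backed conversions
+// would include the headers in this order:
+//   #include <cuda.h>                   // defines CUDA_VERSION
+//   #include "core/framework/float4.h"  // now selects the CUDA conversion path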
+
+#pragma once
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+
+#if defined(__CUDACC__)
+// Needed for CUDA_VERSION check below
+#include <cuda.h>
+#endif
+
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+// 'fp4_interpretation' : unreferenced parameter
+#pragma warning(disable : 4100)
+#endif
+
+#include <cuda_fp4.h>
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+#endif
+
+#include <cassert>
+#include <cmath>
+#include <cstring>
+#include <utility>
+
+#include "core/common/common.h"
+
+namespace onnxruntime {
+
+#if defined(__CUDACC__)
+#define ORT_HOST_DEVICE __host__ __device__
+#else
+#define ORT_HOST_DEVICE
+#endif
+
+struct Float4E2M1x2 {
+ uint8_t val_{0};
+ using UnpackedType = float;
+
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+ using PackedCudaType = __nv_fp4x2_e2m1;
+ using PackedCudaStorageType = __nv_fp4x2_storage_t;
+#endif
+
+ private:
+ ORT_HOST_DEVICE UnpackedType Fp4ToFloatConversionCpuHelper(uint8_t fp4x2, size_t shift) const {
+ assert(shift == 0 || shift == 4);
+
+ constexpr uint8_t sign_bitmask = 0x08;
+ constexpr uint8_t exponent_bitmask = 0x06;
+ constexpr uint8_t mantissa_bitmask = 0x01;
+
+ uint8_t bits_shifted = (fp4x2 >> shift);
+
+ float sign = 1.f;
+ if (bits_shifted & sign_bitmask) {
+ sign = -1.f;
+ }
+
+ int exponent = static_cast<int>((bits_shifted & exponent_bitmask) >> 1);
+ float mantissa = static_cast<float>(bits_shifted & mantissa_bitmask);
+
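+ // E2M1 decode: a zero exponent field denotes a subnormal, sign * (mantissa / 2);
+ // otherwise the value is sign * (1 + mantissa / 2) * 2^(exponent - 1),
+ // giving the representable magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}.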
+ return (exponent == 0) ? (sign * (mantissa / 2.f)) : (sign * (1.f + mantissa / 2.f) * static_cast<float>(1 << (exponent - 1)));
+ }
+
+ ORT_HOST_DEVICE uint8_t FloatToFp4ConversionCpuHelper(float f, size_t shift) const {
+ assert(shift == 0 || shift == 4);
+
+ constexpr uint32_t sign_bitmask = 0x80000000;
+ constexpr uint32_t exponent_bitmask = 0x7F800000;
+ constexpr uint32_t mantissa_bitmask = 0x007FFFFF;
+ constexpr uint32_t zero = 0x00000000;
+
+ uint8_t res = 0;
+
+ uint32_t float_bits = 0;
+ std::memcpy(&float_bits, &f, sizeof(f));
+
+ // NaN always maps to +6 (irrespective of sign)
+ // https://github.com/onnx/onnx/blob/main/docs/docsgen/source/technical/float4.md
+ if (((float_bits & exponent_bitmask) == exponent_bitmask) && (float_bits & mantissa_bitmask)) {
+ return (0x07 << shift);
+ }
+
+ if (float_bits & sign_bitmask) {
+ res = 0x08;
+ }
+
+ // Infinity is sign preserving - magnitude is 6
+ if (((float_bits & exponent_bitmask) == exponent_bitmask) && ((float_bits & mantissa_bitmask) == zero)) {
+ return ((res | 0x07) << shift);
+ }
+
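+ // Round-to-nearest-even onto the representable magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}:
+ // ties at the midpoints (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0) go to the neighbor with an
+ // even mantissa bit, which is why some bounds below are inclusive and others strict.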
+ float f_abs = std::abs(f);
+ if (f_abs > 0.25 && f_abs < 0.75) {
+ res |= 0x01;
+ } else if (f_abs >= 0.75 && f_abs <= 1.25) {
+ res |= 0x02;
+ } else if (f_abs > 1.25 && f_abs < 1.75) {
+ res |= 0x03;
+ } else if (f_abs >= 1.75 && f_abs <= 2.5) {
+ res |= 0x04;
+ } else if (f_abs > 2.5 && f_abs < 3.5) {
+ res |= 0x05;
+ } else if (f_abs >= 3.5 && f_abs <= 5.0) {
+ res |= 0x06;
+ } else if (f_abs > 5.0) {
+ res |= 0x07;
+ }
+
+ return res << shift;
+ }
+
+ public:
+ Float4E2M1x2() = default;
+
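+ // Tag type for constructing directly from a raw packed byte (two E2M1 nibbles)
+ // without any float-to-fp4 conversion.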
+ struct FromBitsT {};
+ static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
+ constexpr ORT_HOST_DEVICE Float4E2M1x2(unsigned char bits, FromBitsT) : val_(bits) {}
+
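+ // Packs two floats into one byte: f1 becomes element 0 and f2 element 1, each rounded
+ // to the nearest representable E2M1 value and saturated to the +/-6 range.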
+ inline explicit ORT_HOST_DEVICE Float4E2M1x2(UnpackedType f1, UnpackedType f2) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+ float2 temp;
+ temp.x = f1;
+ temp.y = f2;
+
+ // Converts input vector of two single precision numbers packed in float2 x
+ // into a vector of two values of fp4 type of the requested kind using specified
+ // rounding mode and saturating the out-of-range values.
+ val_ = __nv_cvt_float2_to_fp4x2(temp, __NV_E2M1, cudaRoundNearest);
+#else
+ val_ = (FloatToFp4ConversionCpuHelper(f1, 0) | FloatToFp4ConversionCpuHelper(f2, 4));
+#endif
+ }
+
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+ inline explicit ORT_HOST_DEVICE Float4E2M1x2(float2 f2) {
+ val_ = __nv_cvt_float2_to_fp4x2(f2, __NV_E2M1, cudaRoundNearest);
+ }
+
+ inline explicit ORT_HOST_DEVICE Float4E2M1x2(const __nv_fp4x2_e2m1& value) {
+ val_ = *reinterpret_cast<const uint8_t*>(&value);
+ }
+
+ inline explicit ORT_HOST_DEVICE operator __nv_fp4x2_e2m1() const {
+ return *reinterpret_cast<const __nv_fp4x2_e2m1*>(&val_);
+ }
+
+ inline ORT_HOST_DEVICE float2 ToCudaFloat2() const {
+ return __half22float2(__nv_cvt_fp4x2_to_halfraw2(static_cast<__nv_fp4x2_storage_t>(val_), __NV_E2M1));
+ }
+#endif
+
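+ // Unpacks the two packed 4-bit elements into a std::pair of floats (element 0 first, element 1 second).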
+ inline ORT_HOST_DEVICE std::pair<float, float> ToFloat2() const {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+ float2 temp = ToCudaFloat2();
+ return std::make_pair(temp.x, temp.y);
+#else
+ return std::make_pair(Fp4ToFloatConversionCpuHelper(val_, 0), Fp4ToFloatConversionCpuHelper(val_, 4));
+#endif
+ }
+
+ inline ORT_HOST_DEVICE uint8_t ToBits() const {
+ return val_;
+ }
+
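+ // Number of packed bytes needed to hold num_float4_elems 4-bit elements (two per byte, rounded up).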
+ static size_t CalcNumFloat4Pairs(size_t num_float4_elems) {
+ return (num_float4_elems + 1) / 2;
+ }
+
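+ // Unpacks `size` 4-bit elements from the packed array into `size` floats; for an odd `size`,
+ // the unused high nibble of the last byte is ignored.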
+ static void UnpackFloat4E2M1ToFloat(const Float4E2M1x2* fp4x2_arr,
+ UnpackedType* flt_arr, size_t size) {
+ auto src = fp4x2_arr;
+ auto dst = flt_arr;
+
+ size_t dst_i = 0;
+
+ for (; dst_i < size - 1; dst_i += 2) {
+ auto src_i = dst_i >> 1;
+ auto flt_pair = src[src_i].ToFloat2();
+ dst[dst_i] = flt_pair.first;
+ dst[dst_i + 1] = flt_pair.second;
+ }
+
+ if (dst_i < size) {
+ auto src_i = dst_i >> 1;
+ dst[dst_i] = fp4x2_arr[src_i].ToFloat2().first;
+ }
+ }
+
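+ // Packs `size` floats into ceil(size / 2) bytes; an odd trailing element is paired with 0.0f.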
+ static void PackFloatToFloat4E2M1(const UnpackedType* flt_arr,
+ Float4E2M1x2* fp4x2_arr, size_t size) {
+ auto src = flt_arr;
+ auto dst = fp4x2_arr;
+
+ size_t src_i = 0;
+
+ for (; src_i < size - 1; src_i += 2) {
+ new (dst) Float4E2M1x2(src[src_i], src[src_i + 1]);
+ ++dst;
+ }
+
+ if (src_i < size) {
+ new (dst) Float4E2M1x2(src[src_i], 0);
+ }
+ }
+
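+ // Maps a flat element index to (byte index, element-within-byte index): element i lives in
+ // byte i / 2 at position i % 2.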
+ static inline std::pair<size_t, size_t> GetTensorElemIndices(size_t index) {
+ return {index >> 1, index & 0x1};
+ }
+
+ inline UnpackedType GetElem(size_t index) const {
+ assert(index <= 1);
+ auto pair = ToFloat2();
+ if (index == 0) {
+ return static_cast<UnpackedType>(pair.first);
+ }
+
+ return static_cast<UnpackedType>(pair.second);
+ }
+};
+
+inline ORT_HOST_DEVICE bool operator==(const Float4E2M1x2& left, const Float4E2M1x2& right) { return left.val_ == right.val_; }
+inline ORT_HOST_DEVICE bool operator!=(const Float4E2M1x2& left, const Float4E2M1x2& right) { return left.val_ != right.val_; }
+
+static_assert(sizeof(Float4E2M1x2) == sizeof(uint8_t));
+} // namespace onnxruntime
+
+namespace std {
+// TODO (hasesh): Does numeric_limits make sense for packed types ?
+// For now, produce limits of each element in a packed format, refine
+// this based on usage later
+template <>
+class numeric_limits<onnxruntime::Float4E2M1x2> {
+ public:
+ static constexpr onnxruntime::Float4E2M1x2 lowest() {
+ return onnxruntime::Float4E2M1x2(0xFF, onnxruntime::Float4E2M1x2::FromBits()); // -6.0
+ }
+
+ static constexpr onnxruntime::Float4E2M1x2 max() {
+ return onnxruntime::Float4E2M1x2(0x77, onnxruntime::Float4E2M1x2::FromBits()); // +6.0
+ }
+
+ static constexpr onnxruntime::Float4E2M1x2 min() {
+ return onnxruntime::Float4E2M1x2(0x22, onnxruntime::Float4E2M1x2::FromBits()); // +1.0
+ }
+
+ static constexpr onnxruntime::Float4E2M1x2 denorm_min() {
+ return onnxruntime::Float4E2M1x2(0x11, onnxruntime::Float4E2M1x2::FromBits()); // +0.5
+ }
+
+ static constexpr bool is_specialized = true;
+ static constexpr bool is_signed = true;
+ static constexpr bool is_integer = false;
+ static constexpr bool is_exact = false;
+ static constexpr bool has_infinity = false;
+ static constexpr bool has_quiet_NaN = false;
+ static constexpr bool has_signaling_NaN = false;
+ static constexpr auto has_denorm = true;
+ static constexpr auto has_denorm_loss = true;
+ static constexpr auto round_style = round_to_nearest;
+ static constexpr bool is_iec559 = false;
+ static constexpr bool is_bounded = true;
+ static constexpr bool is_modulo = false;
+ static constexpr int digits = 2; // (1 mantissa bit + 1 implicit bit)
+ static constexpr int digits10 = 0; // (digits -1) * std::log10(2) rounded down
+ static constexpr int max_digits10 = 1; // Mantissa bits
+ static constexpr int radix = 2;
+ static constexpr int min_exponent = 1; // 2 ^ (1-1) = 1 is the valid normalized value min ceiling we can reach
+ static constexpr int min_exponent10 = 0; // 10 ^ 0 is the valid normalized value min ceiling we can reach
+ static constexpr int max_exponent = 3; // 2 ^ (3-1) = 4 is valid normalized value max ceiling we can reach
+ static constexpr int max_exponent10 = 0; // 10 ^ 0 is the valid normalized value max ceiling we can reach
+ static constexpr auto traps = false;
+ static constexpr auto tinyness_before = false;
+};
+} // namespace std
+
+#endif // DISABLE_FLOAT4_TYPES
diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h
index 9e94cc297f782..c6cfd5a9e2c40 100644
--- a/include/onnxruntime/core/framework/float8.h
+++ b/include/onnxruntime/core/framework/float8.h
@@ -1,11 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+// IMPORTANT NOTE: Users of this file MUST include "cuda.h" before including this header
+// if they would like to leverage the CUDA implementation for the conversion routines
+// in their HOST code (code compiled by MSVC/GCC).
+// This is because there is a check on CUDA_VERSION which is a macro defined in cuda.h.
+// We can't include cuda.h in this header unconditionally because this header is also
+// included in core framework files which are CUDA-agnostic.
+// Not including "cuda.h" in GCC/MSVC builds will fall back to the CPU conversion routines
+// implemented in this file.
+// For code compiled by NVCC which includes this header, this file will automatically
+// include cuda.h (based on the __CUDACC__ macro).
+
#pragma once
#if !defined(DISABLE_FLOAT8_TYPES)
#include "endian.h"
+
+#if defined(__CUDACC__)
+// Needed for CUDA_VERSION check below
+#include <cuda.h>
+#endif
+
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
#include "cuda_fp8.h"
#endif
diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h
index dd2603d214f63..c7f7f23f70334 100644
--- a/include/onnxruntime/core/framework/tensor.h
+++ b/include/onnxruntime/core/framework/tensor.h
@@ -284,9 +284,9 @@ class Tensor final {
///
/// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for
- /// sub-byte data types (e.g., int4).
+ /// sub-byte data types (e.g., int4/float4).
///
- /// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements.
+ /// For element types smaller than 1 byte (e.g., int4/float4), a single storage element stores multiple sub-byte elements.
/// Example: Tensor of shape (4,) has 2 storage elements.
///
/// For element types >= 1 byte, this function returns the product of the shape.
diff --git a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
index e9e28e4864a67..0a18f19b102a6 100644
--- a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
+++ b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h
@@ -10,6 +10,7 @@
#include "core/graph/onnx_protobuf.h"
#endif
+#include "core/framework/float4.h"
#include "core/framework/float8.h"
#include "core/framework/float16.h"
#include "core/framework/int4.h"
@@ -98,6 +99,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType
+#if !defined(DISABLE_FLOAT4_TYPES)
+template <>
+constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType<Float4E2M1x2>() {
+ return ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1;
+}
+#endif
+
template <>
constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() {
return ONNX_NAMESPACE::TensorProto_DataType_INT4;
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 8561de9c8c3b9..fc142abad1c70 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -206,7 +206,9 @@ typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision
// Int4 types were introduced in ONNX 1.16. See https://onnx.ai/onnx/technical/int4.html
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte)
- ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte)
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, // maps to a pair of packed int4 values (size == 1 byte)
+ // Float4 types were introduced in ONNX 1.18. See https://onnx.ai/onnx/technical/float4.html
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, // maps to a pair of packed float4 values (size == 1 byte)
} ONNXTensorElementDataType;
// Synced with onnx TypeProto oneof
diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
index 586732834f0ad..8847d9d7f046e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
+++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
@@ -151,7 +151,7 @@ inline size_t CalcDynamicBlockMemory(const DecoderMaskedMultiHeadAttentionParame
// The extra memory needed if we are not using floats for the final logits.
size_t logits_sz = 0;
- if (sizeof(T) != 4) {
+ if constexpr (sizeof(T) != 4) {
logits_sz = (((total_sequence_length + 3) / 4) * 4 * sizeof(T));
}
diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc
index 06aab16f4a44b..e2192b35b5b20 100644
--- a/onnxruntime/core/framework/data_types.cc
+++ b/onnxruntime/core/framework/data_types.cc
@@ -635,6 +635,11 @@ ORT_REGISTER_TENSOR_TYPE(Float8E4M3FNUZ);
ORT_REGISTER_TENSOR_TYPE(Float8E5M2);
ORT_REGISTER_TENSOR_TYPE(Float8E5M2FNUZ);
#endif
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+ORT_REGISTER_TENSOR_TYPE(Float4E2M1x2);
+#endif
+
ORT_REGISTER_TENSOR_TYPE(Int4x2);
ORT_REGISTER_TENSOR_TYPE(UInt4x2);
@@ -812,6 +817,9 @@ void RegisterAllProtos(const std::function& reg_fn) {
REGISTER_TENSOR_PROTO(Float8E4M3FNUZ, reg_fn);
REGISTER_TENSOR_PROTO(Float8E5M2, reg_fn);
REGISTER_TENSOR_PROTO(Float8E5M2FNUZ, reg_fn);
+#endif
+#if !defined(DISABLE_FLOAT4_TYPES)
+ REGISTER_TENSOR_PROTO(Float4E2M1x2, reg_fn);
#endif
REGISTER_TENSOR_PROTO(Int4x2, reg_fn);
REGISTER_TENSOR_PROTO(UInt4x2, reg_fn);
@@ -987,6 +995,8 @@ const char* DataTypeImpl::ToString(MLDataType type) {
return "Float8E5M2";
case TensorProto_DataType_FLOAT8E5M2FNUZ:
return "Float8E5M2FNUZ";
+ case TensorProto_DataType_FLOAT4E2M1:
+ return "Float4E2M1";
case TensorProto_DataType_INT4:
return "Int4x2";
case TensorProto_DataType_UINT4:
@@ -1048,7 +1058,6 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) {
return DataTypeImpl::GetTensorType<BFloat16>()->AsTensorType();
#if !defined(DISABLE_FLOAT8_TYPES)
-
case TensorProto_DataType_FLOAT8E4M3FN:
return DataTypeImpl::GetTensorType<Float8E4M3FN>()->AsTensorType();
case TensorProto_DataType_FLOAT8E4M3FNUZ:
@@ -1057,7 +1066,10 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) {
return DataTypeImpl::GetTensorType<Float8E5M2>()->AsTensorType();
case TensorProto_DataType_FLOAT8E5M2FNUZ:
return DataTypeImpl::GetTensorType<Float8E5M2FNUZ>()->AsTensorType();
-
+#endif
+#if !defined(DISABLE_FLOAT4_TYPES)
+ case TensorProto_DataType_FLOAT4E2M1:
+ return DataTypeImpl::GetTensorType<Float4E2M1x2>()->AsTensorType();
#endif
case TensorProto_DataType_INT4:
return DataTypeImpl::GetTensorType<Int4x2>()->AsTensorType();
@@ -1209,6 +1221,13 @@ ORT_REGISTER_PRIM_TYPE(Float8E5M2);
ORT_REGISTER_PRIM_TYPE(Float8E5M2FNUZ);
#endif
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+
+ORT_REGISTER_PRIM_SUBBYTE_TYPE(Float4E2M1x2, 2);
+
+#endif
+
ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2);
ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2);
@@ -1307,6 +1326,12 @@ const std::vector& DataTypeImpl::AllTensorTypesIRv10() {
return all_tensor_types;
}
+const std::vector<MLDataType>& DataTypeImpl::AllTensorTypesIRv11() {
+ static std::vector<MLDataType> all_tensor_types =
+ GetTensorTypesFromTypeList<element_type_lists::AllIRv11>();
+ return all_tensor_types;
+}
+
const std::vector<MLDataType>& DataTypeImpl::AllFixedSizeSequenceTensorTypes() {
return AllFixedSizeSequenceTensorTypesIRv4();
}
diff --git a/onnxruntime/core/framework/element_type_lists.h b/onnxruntime/core/framework/element_type_lists.h
index 2478dc27162ac..2d6eb6a6be580 100644
--- a/onnxruntime/core/framework/element_type_lists.h
+++ b/onnxruntime/core/framework/element_type_lists.h
@@ -12,6 +12,7 @@
#include "core/framework/float8.h"
#include "core/framework/float16.h"
#include "core/framework/int4.h"
+#include "core/framework/float4.h"
namespace onnxruntime {
@@ -89,6 +90,20 @@ using AllIRv10 =
UInt4x2,
Int4x2>;
+#if !defined(DISABLE_FLOAT4_TYPES)
+using AllIRv11 =
+ boost::mp11::mp_push_back<
+ AllIRv10,
+ Float4E2M1x2>;
+#else
+using AllIRv11 = AllIRv10;
+#endif
+
+// TODO: This needs an upgrade to some newer version, but it has been
+// at this version for a while and it needs changes at the use sites
+// where the types in the newer IR versions are not supported.
+// This may need a sweep across multiple EPs as this is mostly used
+// for kernel registration.
using All = AllIRv4;
#if !defined(DISABLE_FLOAT8_TYPES)
@@ -100,6 +115,10 @@ using AllFloat8 =
Float8E5M2FNUZ>;
#endif
+#if !defined(DISABLE_FLOAT4_TYPES)
+using AllFloat4 = TypeList<Float4E2M1x2>;
+#endif
+
using AllIeeeFloat =
TypeList<
float,
diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc
index d3e435c0341b0..e8eb6f4ce9d02 100644
--- a/onnxruntime/core/framework/fallback_cpu_capability.cc
+++ b/onnxruntime/core/framework/fallback_cpu_capability.cc
@@ -130,13 +130,15 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe
for (size_t i = 0; i < node->InputDefs().size(); ++i) {
auto* input = node->InputDefs()[i];
- // skip placing on CPU if the data typs is float16 or bfloat16 or float8e4m3fn, float8e4m3fnuz, floate5m2, floate5m2fnuz
+ // skip placing on CPU if the data type is float16 or bfloat16 or
+ // float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz or float4e2m1
if (input->Type() == DataTypeUtils::ToType("float16") ||
input->Type() == DataTypeUtils::ToType("bfloat16") ||
input->Type() == DataTypeUtils::ToType("float8e4m3fn") ||
input->Type() == DataTypeUtils::ToType("float8e4m3fnuz") ||
input->Type() == DataTypeUtils::ToType("float8e5m2") ||
- input->Type() == DataTypeUtils::ToType("float8e5m2fnuz")) {
+ input->Type() == DataTypeUtils::ToType("float8e5m2fnuz") ||
+ input->Type() == DataTypeUtils::ToType("float4e2m1")) {
place_in_cpu = false;
break;
}
diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
index 1370580bad4f6..461e82d72dc83 100644
--- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc
+++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc
@@ -84,6 +84,9 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) {
case TensorType::TensorProto_DataType_UINT4: {
return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
} // maps to a pair of uint4 (size == 1 byte)
+ case TensorType::TensorProto_DataType_FLOAT4E2M1: {
+ return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1;
+ } // maps to a pair of float4 (size == 1 byte)
default: {
return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}
diff --git a/onnxruntime/core/framework/print_tensor_statistics_utils.h b/onnxruntime/core/framework/print_tensor_statistics_utils.h
index c2030424ef19d..64d60e048a112 100644
--- a/onnxruntime/core/framework/print_tensor_statistics_utils.h
+++ b/onnxruntime/core/framework/print_tensor_statistics_utils.h
@@ -94,33 +94,36 @@ void PrintCommonStats(const T* data, size_t count, TensorStatisticsData& tensor_
}
}
-#define DEF_PRINT_COMMON_STATS_INT4(INT4_TYPE) \
- template <> \
- inline void PrintCommonStats( \
- const INT4_TYPE* data, size_t count, TensorStatisticsData&) { \
- using UnpackedType = typename INT4_TYPE::UnpackedType; \
- UnpackedType min = data[0].GetElem(0); \
- UnpackedType max = min; \
- for (size_t i = 1; i < count; i++) { \
- auto indices = INT4_TYPE::GetTensorElemIndices(i); \
- auto value = data[indices.first].GetElem(indices.second); \
- if (value > max) { \
- max = value; \
- } \
- if (value < min) { \
- min = value; \
- } \
- } \
- \
- std::cout << "Min="; \
- PrintValue(min); \
- \
- std::cout << ",Max="; \
- PrintValue(max); \
+#define DEF_PRINT_COMMON_STATS_4BIT(FOUR_BIT_TYPE) \
+ template <> \
+ inline void PrintCommonStats( \
+ const FOUR_BIT_TYPE* data, size_t count, TensorStatisticsData&) { \
+ using UnpackedType = typename FOUR_BIT_TYPE::UnpackedType; \
+ UnpackedType min = data[0].GetElem(0); \
+ UnpackedType max = min; \
+ for (size_t i = 1; i < count; i++) { \
+ auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(i); \
+ auto value = data[indices.first].GetElem(indices.second); \
+ if (value > max) { \
+ max = value; \
+ } \
+ if (value < min) { \
+ min = value; \
+ } \
+ } \
+ \
+ std::cout << "Min="; \
+ PrintValue(min); \
+ \
+ std::cout << ",Max="; \
+ PrintValue(max); \
}
-DEF_PRINT_COMMON_STATS_INT4(Int4x2)
-DEF_PRINT_COMMON_STATS_INT4(UInt4x2)
+DEF_PRINT_COMMON_STATS_4BIT(Int4x2)
+DEF_PRINT_COMMON_STATS_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_COMMON_STATS_4BIT(Float4E2M1x2)
+#endif
template <typename T>
void PrintHalfStats(const T* data, size_t count) {
diff --git a/onnxruntime/core/framework/print_tensor_utils.h b/onnxruntime/core/framework/print_tensor_utils.h
index e6af5e9e5832b..47be8b8dc2057 100644
--- a/onnxruntime/core/framework/print_tensor_utils.h
+++ b/onnxruntime/core/framework/print_tensor_utils.h
@@ -74,28 +74,31 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
std::cout << std::endl;
}
-// INT4 - Print snippet of 2D tensor with shape (dim0, dim1)
-#define DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(INT4_TYPE) \
- template <> \
- inline void PrintCpuTensorSnippet(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, \
- int64_t edge_items) { \
- for (int64_t i = 0; i < dim0; i++) { \
- SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
- auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- for (int64_t j = 1; j < dim1; j++) { \
- SKIP_NON_EDGE_ITEMS_LAST_DIM(dim1, j, edge_items); \
- std::cout << ", "; \
- indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- } \
- std::cout << std::endl; \
- } \
- std::cout << std::endl; \
+// 4 BIT TYPE - Print snippet of 2D tensor with shape (dim0, dim1)
+#define DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(FOUR_BIT_TYPE) \
+ template <> \
+ inline void PrintCpuTensorSnippet(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, \
+ int64_t edge_items) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
+ auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t j = 1; j < dim1; j++) { \
+ SKIP_NON_EDGE_ITEMS_LAST_DIM(dim1, j, edge_items); \
+ std::cout << ", "; \
+ indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
}
-DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(Float4E2M1x2)
+#endif
// Print snippet of 3D tensor with shape (dim0, dim1, dim2)
template <typename T>
@@ -117,32 +120,35 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
std::cout << std::endl;
}
-// INT4 - Print snippet of 3D tensor with shape (dim0, dim1, dim2)
-#define DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(INT4_TYPE) \
- template <> \
- inline void PrintCpuTensorSnippet(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2, \
- int64_t edge_items) { \
- for (int64_t i = 0; i < dim0; i++) { \
- SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
- for (int64_t j = 0; j < dim1; j++) { \
- SKIP_NON_EDGE_ITEMS(dim1, j, edge_items); \
- auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- for (int64_t k = 1; k < dim2; k++) { \
- SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items); \
- std::cout << ", "; \
- indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- } \
- std::cout << std::endl; \
- } \
- std::cout << std::endl; \
- } \
- std::cout << std::endl; \
+// 4 BIT TYPE - Print snippet of 3D tensor with shape (dim0, dim1, dim2)
+#define DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(FOUR_BIT_TYPE) \
+ template <> \
+ inline void PrintCpuTensorSnippet(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2, \
+ int64_t edge_items) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
+ for (int64_t j = 0; j < dim1; j++) { \
+ SKIP_NON_EDGE_ITEMS(dim1, j, edge_items); \
+ auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t k = 1; k < dim2; k++) { \
+ SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items); \
+ std::cout << ", "; \
+ indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
}
-DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(Float4E2M1x2)
+#endif
// Print 2D tensor
template <typename T>
@@ -158,25 +164,28 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1) {
std::cout << std::endl;
}
-// INT4 - Print 2D tensor
-#define DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(INT4_TYPE) \
- template <> \
- inline void PrintCpuTensorFull(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1) { \
- for (int64_t i = 0; i < dim0; i++) { \
- auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- for (int64_t j = 1; j < dim1; j++) { \
- std::cout << ", "; \
- indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- } \
- std::cout << std::endl; \
- } \
- std::cout << std::endl; \
+// 4 BIT TYPE - Print 2D tensor
+#define DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(FOUR_BIT_TYPE) \
+ template <> \
+ inline void PrintCpuTensorFull(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t j = 1; j < dim1; j++) { \
+ std::cout << ", "; \
+ indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
}
-DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(Float4E2M1x2)
+#endif
// Print 3D tensor
template <typename T>
@@ -195,28 +204,31 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim
std::cout << std::endl;
}
-// INT4 - Print 3D tensor
-#define DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(INT4_TYPE) \
- template <> \
- inline void PrintCpuTensorFull(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2) { \
- for (int64_t i = 0; i < dim0; i++) { \
- for (int64_t j = 0; j < dim1; j++) { \
- auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- for (int64_t k = 1; k < dim2; k++) { \
- std::cout << ", "; \
- indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
- PrintValue(tensor[indices.first].GetElem(indices.second)); \
- } \
- std::cout << std::endl; \
- } \
- std::cout << std::endl; \
- } \
- std::cout << std::endl; \
+// 4 BIT TYPE - Print 3D tensor
+#define DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(FOUR_BIT_TYPE) \
+ template <> \
+ inline void PrintCpuTensorFull(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2) { \
+ for (int64_t i = 0; i < dim0; i++) { \
+ for (int64_t j = 0; j < dim1; j++) { \
+ auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ for (int64_t k = 1; k < dim2; k++) { \
+ std::cout << ", "; \
+ indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
+ PrintValue(tensor[indices.first].GetElem(indices.second)); \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
+ } \
+ std::cout << std::endl; \
}
-DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(Float4E2M1x2)
+#endif
template <typename T>
void PrintCpuTensor(const onnxruntime::Tensor& tensor,
diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc
index 60d768cc59a5d..133f21be97c46 100644
--- a/onnxruntime/core/framework/tensor.cc
+++ b/onnxruntime/core/framework/tensor.cc
@@ -30,7 +30,7 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st
///
/// Get the number of elements for a Tensor of the given element type and shape size.
///
-/// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements.
+/// For element types smaller than 1 byte (e.g., int4/float4), a single storage element stores multiple sub-byte elements.
/// Example: Tensor of shape_size 4 has 2 storage elements.
///
/// For element types >= 1 byte, this function returns the product of the shape.
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc
index 61dcd4ea8035f..bef8df51f6d03 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.cc
+++ b/onnxruntime/core/framework/tensor_type_and_shape.cc
@@ -199,6 +199,9 @@ constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementData
case o::TensorProto_DataType_UINT4:
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
break;
+ case o::TensorProto_DataType_FLOAT4E2M1:
+ type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1;
+ break;
default:
type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
break;
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index ff440b595e499..e686303f3ebeb 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -59,7 +59,7 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) {
return t; \
}
-#define TO_TENSOR_ORT_TYPE_INT4(TYPE) \
+#define TO_TENSOR_ORT_TYPE_4BIT_TYPE(TYPE) \
template <> \
TensorProto ToTensor(const onnxruntime::TYPE& value) { \
return ToScalarTensor(ToTensorProtoElementType<onnxruntime::TYPE>(), static_cast<int32_t>(value.ToBits())); \
@@ -84,8 +84,11 @@ TO_TENSOR_ORT_TYPE(Float8E4M3FNUZ)
TO_TENSOR_ORT_TYPE(Float8E5M2)
TO_TENSOR_ORT_TYPE(Float8E5M2FNUZ)
#endif
-TO_TENSOR_ORT_TYPE_INT4(Int4x2)
-TO_TENSOR_ORT_TYPE_INT4(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+TO_TENSOR_ORT_TYPE_4BIT_TYPE(Float4E2M1x2)
+#endif
+TO_TENSOR_ORT_TYPE_4BIT_TYPE(Int4x2)
+TO_TENSOR_ORT_TYPE_4BIT_TYPE(UInt4x2)
bool operator==(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l,
const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) {
@@ -141,28 +144,32 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t
reinterpret_cast(p_data));
}
-#define DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(INT4_TYPE) \
- template <> \
- Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, \
- /*out*/ INT4_TYPE* p_data) { \
- static_assert(std::is_trivially_copyable<INT4_TYPE>::value, "T must be trivially copyable"); \
- \
- ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \
- \
- size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \
- ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); \
- \
- gsl::span<const INT4_TYPE> src_span = \
- gsl::make_span(reinterpret_cast<const INT4_TYPE*>(raw_data), num_packed_pairs); \
- gsl::span<INT4_TYPE> dst_span = gsl::make_span(p_data, num_packed_pairs); \
- \
- std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \
- \
- return Status::OK(); \
- }
-
-DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2)
-DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2)
+#define DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(FOUR_BIT_TYPE, CalcPairFun) \
+ template <> \
+ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, \
+ /*out*/ FOUR_BIT_TYPE* p_data) { \
+ static_assert(std::is_trivially_copyable<FOUR_BIT_TYPE>::value, "T must be trivially copyable"); \
+ \
+ ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \
+ \
+ size_t num_packed_pairs = FOUR_BIT_TYPE::CalcPairFun(expected_num_elements); \
+ ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); \
+ \
+ gsl::span<const FOUR_BIT_TYPE> src_span = \
+ gsl::make_span(reinterpret_cast<const FOUR_BIT_TYPE*>(raw_data), num_packed_pairs); \
+ gsl::span<FOUR_BIT_TYPE> dst_span = gsl::make_span(p_data, num_packed_pairs); \
+ \
+ std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \
+ \
+ return Status::OK(); \
+ }
+
+DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2, CalcNumInt4Pairs)
+DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2, CalcNumInt4Pairs)
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs)
+#endif
// Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully.
// Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr
@@ -437,31 +444,35 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor,
reinterpret_cast(p_data));
}
-#define DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(INT4_TYPE) \
- template <> \
- Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \
- const std::filesystem::path& tensor_proto_dir, \
- size_t expected_num_elements, /*out*/ INT4_TYPE* p_data) { \
- static_assert(std::is_trivially_copyable<INT4_TYPE>::value, "T must be trivially copyable"); \
- \
- ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \
- std::vector<uint8_t> unpacked_tensor; \
- ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \
- \
- size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \
- ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \
- \
- gsl::span<const INT4_TYPE> src_span = \
- gsl::make_span(reinterpret_cast<const INT4_TYPE*>(unpacked_tensor.data()), num_packed_pairs); \
- gsl::span<INT4_TYPE> dst_span = gsl::make_span(p_data, expected_num_elements); \
- \
- std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \
- \
- return Status::OK(); \
- }
-
-DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2)
-DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2)
+#define DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(FOUR_BIT_TYPE, CalcPairFun) \
+ template <> \
+ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \
+ const std::filesystem::path& tensor_proto_dir, \
+ size_t expected_num_elements, /*out*/ FOUR_BIT_TYPE* p_data) { \
+ static_assert(std::is_trivially_copyable<FOUR_BIT_TYPE>::value, "T must be trivially copyable"); \
+ \
+ ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \
+ std::vector<uint8_t> unpacked_tensor; \
+ ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \
+ \
+ size_t num_packed_pairs = FOUR_BIT_TYPE::CalcPairFun(expected_num_elements); \
+ ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \
+ \
+ gsl::span<const FOUR_BIT_TYPE> src_span = \
+ gsl::make_span(reinterpret_cast<const FOUR_BIT_TYPE*>(unpacked_tensor.data()), num_packed_pairs); \
+ gsl::span<FOUR_BIT_TYPE> dst_span = gsl::make_span(p_data, expected_num_elements); \
+ \
+ std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \
+ \
+ return Status::OK(); \
+ }
+
+DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2, CalcNumInt4Pairs)
+DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2, CalcNumInt4Pairs)
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs)
+#endif
#define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \
template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const std::filesystem::path&, \
@@ -843,6 +854,43 @@ DEFINE_INT4_UNPACK_TENSOR_IMPL(Int4x2, TensorProto_DataType_INT4)
// UnpackTensor
DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4)
+#if !defined(DISABLE_FLOAT4_TYPES)
+
+template <>
+Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len,
+ /*out*/ Float4E2M1x2* p_data, size_t expected_num_elems) {
+ if (nullptr == p_data) {
+ const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size();
+ return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
+ }
+ if (ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1 != tensor.data_type()) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT);
+ }
+
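+ // Each Float4E2M1x2 packs two FP4 elements, so the proto is expected to hold CalcNumFloat4Pairs(expected_num_elems) packed entries.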
+ size_t expected_float4_pairs = Float4E2M1x2::CalcNumFloat4Pairs(expected_num_elems);
+
+ if (raw_data != nullptr) {
+ return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data);
+ }
+
+ ORT_RETURN_IF_NOT(static_cast<size_t>(tensor.int32_data_size()) == expected_float4_pairs,
+ "UnpackTensor: the pre-allocated size does not match the size in proto");
+
+ constexpr int max_value = std::numeric_limits<uint8_t>::max();
+
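+ // Fall back to int32_data: each entry carries one packed byte (two FP4 values); reject values that do not fit in a byte before reinterpreting the bits.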
+ for (int i = 0; i < static_cast<int>(tensor.int32_data_size()); i++) {
+ int v = tensor.int32_data()[i];
+ if (v < 0 || v > max_value) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "data overflow");
+ }
+ p_data[i] = Float4E2M1x2(static_cast<uint8_t>(v), Float4E2M1x2::FromBits());
+ }
+
+ return Status::OK();
+}
+
+#endif
+
// UnpackTensor from raw data, external data or the type specific data field.
// Uses the model path to construct the full path for loading external data. In case when model_path is empty
// it uses current directory.
@@ -906,6 +954,15 @@ INSTANTIATE_UNPACK_TENSOR(UInt4x2)
} \
break;
+#if !defined(DISABLE_FLOAT4_TYPES)
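+// Mirrors CASE_PROTO_TRACE_INT4: the allocation size is computed from the number of packed FP4 pairs rather than individual elements.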
+#define CASE_PROTO_TRACE_FLOAT4(X, Y) \
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \
+ if (!IAllocator::CalcMemSizeForArrayWithAlignment<alignment>(Y::CalcNumFloat4Pairs(size), sizeof(Y), out)) { \
+ return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \
+ } \
+ break;
+#endif
+
template <size_t alignment>
common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, int32_t element_type, size_t* out) {
const auto size = narrow<size_t>(shape.Size());
@@ -932,6 +989,10 @@ common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, in
#endif
CASE_PROTO_TRACE_INT4(UINT4, UInt4x2);
CASE_PROTO_TRACE_INT4(INT4, Int4x2);
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+ CASE_PROTO_TRACE_FLOAT4(FLOAT4E2M1, Float4E2M1x2);
+#endif
default:
return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED);
}
@@ -1322,6 +1383,11 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa
#endif
CASE_PROTO(INT4, Int4x2);
CASE_PROTO(UINT4, UInt4x2);
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+ CASE_PROTO(FLOAT4E2M1, Float4E2M1x2);
+#endif
+
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_STRING:
ORT_RETURN_IF_ERROR(UnpackTensor(tensor_proto, raw_data, raw_data_len,
static_cast<std::string*>(preallocated),
@@ -1402,6 +1468,11 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) {
#endif
CASE_TYPE(UINT4)
CASE_TYPE(INT4)
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+ CASE_TYPE(FLOAT4E2M1)
+#endif
+
default:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}
@@ -1957,11 +2028,11 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T
break; \
}
-#define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \
+#define CASE_UNPACK_4BIT_TYPE(TYPE, ELEMENT_TYPE, DATA_SIZE, CALC_PAIR_FUN) \
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \
TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \
size_t element_count = static_cast<size_t>(tensor_shape.Size()); \
- size_t packed_element_count = ELEMENT_TYPE::CalcNumInt4Pairs(element_count); \
+ size_t packed_element_count = ELEMENT_TYPE::CALC_PAIR_FUN(element_count); \
unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \
return onnxruntime::utils::UnpackTensor(initializer, \
initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \
@@ -2004,8 +2075,13 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer,
CASE_UNPACK(FLOAT8E5M2, onnxruntime::Float8E5M2, int32_data_size);
CASE_UNPACK(FLOAT8E5M2FNUZ, onnxruntime::Float8E5M2FNUZ, int32_data_size);
#endif
- CASE_UNPACK_INT4(INT4, Int4x2, int32_data_size);
- CASE_UNPACK_INT4(UINT4, UInt4x2, int32_data_size);
+ CASE_UNPACK_4BIT_TYPE(INT4, Int4x2, int32_data_size, CalcNumInt4Pairs);
+ CASE_UNPACK_4BIT_TYPE(UINT4, UInt4x2, int32_data_size, CalcNumInt4Pairs);
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+ CASE_UNPACK_4BIT_TYPE(FLOAT4E2M1, Float4E2M1x2, int32_data_size, CalcNumFloat4Pairs);
+#endif
+
default:
break;
}
diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h
index 6b5c404e26b7f..3a23093a5445b 100644
--- a/onnxruntime/core/framework/utils.h
+++ b/onnxruntime/core/framework/utils.h
@@ -229,6 +229,13 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4;
}
+#if !defined(DISABLE_FLOAT4_TYPES)
+template <>
+constexpr ONNXTensorElementDataType GetONNXTensorElementDataType<Float4E2M1x2>() {
+ return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1;
+}
+#endif
+
int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType);
#ifdef ENABLE_TRAINING
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index 9123c0bd76ec7..5fe700e1711a4 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -8,6 +8,7 @@
#include
#include
#include
+#include
#include
#include
#include
diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h
index 0b56cac1038e4..bef0559d967d0 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.h
+++ b/onnxruntime/core/providers/cuda/cuda_common.h
@@ -13,6 +13,7 @@
#include "core/common/status.h"
#include "core/framework/float8.h"
#include "core/framework/float16.h"
+#include "core/framework/float4.h"
#include "core/providers/cuda/cuda_pch.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/shared_inc/fast_divmod.h"
@@ -99,6 +100,15 @@ class ToCudaType {
#endif
+#if defined(ENABLE_FP4) && !defined(DISABLE_FLOAT4_TYPES)
+// ENABLE_FP4 is only set if CUDA SDK version is >= 12.8
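+// Maps the packed host pair type to its device-side representation (Float4E2M1x2::PackedCudaType) for kernel launches.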
+template <>
+class ToCudaType<Float4E2M1x2> {
+ public:
+ typedef Float4E2M1x2::PackedCudaType MappedType;
+};
+#endif
+
inline bool CalculateFdmStrides(gsl::span<fast_divmod> p, const std::vector<int64_t>& dims) {
int stride = 1;
if (dims.empty() || p.size() < dims.size())
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index e036c7764d041..7c1b6e8f393cf 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1392,22 +1392,22 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize);
// Opset 19
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, BFloat16, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int8_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int16_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int64_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint16_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint32_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint64_t, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, bool, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, BFloat16, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int8_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int16_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int32_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int64_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint16_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint32_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint64_t, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Cast);
#if !defined(DISABLE_FLOAT8_TYPES)
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, Cast);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, Cast);
#endif
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, DequantizeLinear);
@@ -1491,6 +1491,27 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int8_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int16_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int32_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int64_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint8_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint16_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint32_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint64_t, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, bool, Cast);
+#if !defined(DISABLE_FLOAT8_TYPES)
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast);
+#endif
+
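+// FLOAT4E2M1 support in Cast is new in opset 23, so only the opset 23 kernel is typed on Float4E2M1x2.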
+#if !defined(DISABLE_FLOAT4_TYPES)
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float4E2M1x2, Cast);
+#endif
#endif
@@ -2384,23 +2405,23 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- // Opset 19
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, BFloat16, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int8_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int16_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int64_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint16_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint32_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint64_t, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, bool, Cast)>,
+ // Opset 19-20
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, BFloat16, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int8_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int16_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int32_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int64_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint16_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint32_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint64_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Cast)>,
#if !defined(DISABLE_FLOAT8_TYPES)
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, Cast)>,
#endif
BuildKernelCreateInfo,
@@ -2484,6 +2505,26 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int8_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int16_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int32_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int64_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint8_t, Cast)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint16_t, Cast)>,
+ BuildKernelCreateInfo