From 1c3bacdab238107557ce689c156fb9968dcce4d5 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 3 Sep 2025 11:33:49 -0700 Subject: [PATCH 1/3] Support fp4 type in ORT (#25767) https://github.com/onnx/onnx/pull/6318 and https://github.com/onnx/onnx/pull/6283 added FP4 support to ONNX. This change introduces the FP4 type in ORT and adds type support to one relevant operator (`Cast`) as a proof-of-concept for the type integration into ORT. More op support will be added on a need-basis. This change took inspiration from the following PRs: https://github.com/microsoft/onnxruntime/pull/14731 https://github.com/microsoft/onnxruntime/pull/22228 https://github.com/microsoft/onnxruntime/pull/20362 Some notes: 1) Only `tensor` type gets support for FP4 initially. Secondary types like `seq(tensor)`, `sparse_tensor`, `optional` do not get support (so as to not introduce unnecessary bloat to the framework without a solid use-case) 2) Flatbuffer related files receive no updates in this PR Be able to run FP4 models with ORT --- .github/workflows/linux_minimal_build.yml | 14 +- cmake/CMakeLists.txt | 5 + docs/OperatorKernels.md | 10 +- .../onnxruntime/core/framework/data_types.h | 13 +- .../core/framework/data_types_internal.h | 258 +++++++++++++- include/onnxruntime/core/framework/float4.h | 297 ++++++++++++++++ include/onnxruntime/core/framework/float8.h | 17 + include/onnxruntime/core/framework/tensor.h | 4 +- .../framework/to_tensor_proto_element_type.h | 9 + .../core/session/onnxruntime_c_api.h | 4 +- ...er_masked_multihead_attention_impl_utils.h | 2 +- onnxruntime/core/framework/data_types.cc | 29 +- .../core/framework/element_type_lists.h | 19 + .../core/framework/fallback_cpu_capability.cc | 6 +- .../framework/onnxruntime_map_type_info.cc | 3 + onnxruntime/core/framework/tensor.cc | 2 +- .../core/framework/tensor_type_and_shape.cc | 3 + .../core/framework/tensorprotoutils.cc | 184 +++++++--- onnxruntime/core/framework/utils.h | 7 + .../core/providers/cuda/cu_inc/common.cuh | 1 + onnxruntime/core/providers/cuda/cuda_common.h | 10 + .../providers/cuda/cuda_execution_provider.cc | 103 ++++-- .../providers/cuda/cuda_type_conversion.h | 11 + .../core/providers/cuda/tensor/cast_op.cc | 333 +++++++++++------- .../core/providers/cuda/tensor/cast_op.cu | 191 +++++++++- .../core/providers/cuda/tensor/cast_op.h | 5 + .../providers/shared_library/provider_api.h | 11 +- .../provider_bridge_provider.cc | 10 + .../shared_library/provider_interfaces.h | 18 + .../shared_library/provider_wrappedtypes.h | 17 + .../core/session/provider_bridge_ort.cc | 29 ++ .../test/contrib_ops/gemm_float8_test.cc | 10 +- onnxruntime/test/framework/float_4_test.cc | 158 +++++++++ onnxruntime/test/onnx/tensorprotoutils.cc | 85 +++++ onnxruntime/test/providers/base_tester.h | 11 +- onnxruntime/test/providers/checkers.cc | 48 ++- .../test/providers/cpu/tensor/cast_op_test.cc | 118 ++++++- tools/ci_build/build.py | 3 + tools/ci_build/build_args.py | 2 +- 39 files changed, 1814 insertions(+), 246 deletions(-) create mode 100644 include/onnxruntime/core/framework/float4.h create mode 100644 onnxruntime/test/framework/float_4_test.cc diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index 92cdbb70e9858..93c201c3c6d60 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -325,7 +325,7 @@ jobs: --build_wheel --use_binskim_compliant_compile_flags --disable_ml_ops - --disable_types sparsetensor float8 optional + --disable_types 
sparsetensor float4 float8 optional --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF @@ -341,7 +341,7 @@ jobs: --build_wheel --use_binskim_compliant_compile_flags --disable_ml_ops - --disable_types sparsetensor float8 optional + --disable_types sparsetensor float4 float8 optional --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF @@ -358,7 +358,7 @@ jobs: --build_wheel --use_binskim_compliant_compile_flags --disable_ml_ops - --disable_types sparsetensor float8 optional + --disable_types sparsetensor float4 float8 optional --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF @@ -408,7 +408,7 @@ jobs: --disable_ml_ops --skip_tests --enable_reduced_operator_type_support - --disable_types sparsetensor optional float8 + --disable_types sparsetensor optional float4 float8 --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF @@ -427,7 +427,7 @@ jobs: --disable_ml_ops --skip_tests --enable_reduced_operator_type_support - --disable_types sparsetensor optional float8 + --disable_types sparsetensor optional float4 float8 --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF @@ -483,7 +483,7 @@ jobs: --disable_ml_ops --skip_tests --enable_reduced_operator_type_support - --disable_types sparsetensor optional float8 + --disable_types sparsetensor optional float4 float8 --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF @@ -502,7 +502,7 @@ jobs: --disable_ml_ops --skip_tests --enable_reduced_operator_type_support - --disable_types sparsetensor optional float8 + --disable_types sparsetensor optional float4 float8 --include_ops_by_config /onnxruntime_src/build/.test_data/include_no_operators.config --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index baf21745e0e40..7ac4729077734 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -155,6 +155,7 @@ option(onnxruntime_DISABLE_ML_OPS "Disable traditional ML ops" OFF) option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OFF) option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) +option(onnxruntime_DISABLE_FLOAT4_TYPES "Disable float 4 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." 
OFF) cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) @@ -1029,6 +1030,10 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT8_TYPES) endif() + if (onnxruntime_DISABLE_FLOAT4_TYPES) + target_compile_definitions(${target_name} PRIVATE DISABLE_FLOAT4_TYPES) + endif() + if (onnxruntime_ENABLE_ATEN) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 660c63d056335..a6d69198fadcc 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -622,10 +622,12 @@ Do not modify directly.* |||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| |||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 12]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[6, 8]|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|Clip|*in* input:**T**
*in* min:**T**
*in* max:**T**
*out* output:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int64), tensor(int8), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/framework/data_types.h b/include/onnxruntime/core/framework/data_types.h index d8822b3e452d5..7f8e94305656e 100644 --- a/include/onnxruntime/core/framework/data_types.h +++ b/include/onnxruntime/core/framework/data_types.h @@ -16,6 +16,7 @@ #include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/framework/int4.h" +#include "core/framework/float4.h" #include "core/graph/onnx_protobuf.h" #include "core/framework/to_tensor_proto_element_type.h" @@ -209,6 +210,7 @@ class DataTypeImpl { static const std::vector& AllTensorTypesIRv4(); static const std::vector& AllTensorTypesIRv9(); static const std::vector& AllTensorTypesIRv10(); + static const std::vector& AllTensorTypesIRv11(); static const std::vector& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated static const std::vector& AllFixedSizeTensorTypesIRv4(); @@ -287,6 +289,10 @@ struct IsTensorContainedType : public IsAnyOf { }; @@ -302,6 +308,10 @@ struct IsSparseTensorContainedType : public IsAnyOf { }; @@ -921,7 +931,7 @@ class OpaqueType : public NonTensorType { * * \details This class contains an integer constant that can be * used for input data type dispatching. This class also stores the number of subelements per size units. - * Example: For int4, the size unit is 1 byte and the number of subelements is 2. + * Example: For float4/int4, the size unit is 1 byte and the number of subelements is 2. * */ class PrimitiveDataTypeBase : public DataTypeImpl { @@ -1101,6 +1111,7 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const { // Registers a subbyte primitive. // Examples: // - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2) +// - Float4E2M1x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Float4E2M1x2, 2) // - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8) #define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \ template <> \ diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index 4cc57ba4b5391..7c66b676117c9 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -35,7 +35,263 @@ namespace utils { // Invoking DataTypeImpl::GetType() for switching on input types is discouraged and should be avoided. // Every primitive type carries with it an integer constant that can be used for quick switching on types. -#if !defined(DISABLE_FLOAT8_TYPES) +#if !defined(DISABLE_FLOAT8_TYPES) && !defined(DISABLE_FLOAT4_TYPES) + +#define DispatchOnTensorType(tensor_type, function, ...) 
\ + switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_STRING: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT8: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT32: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT64: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + function(__VA_ARGS__); \ + break; \ + default: \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ + } + +#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) 
\ + switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_STRING: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT8: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT32: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT64: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + retval = function(__VA_ARGS__); \ + break; \ + default: \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ + } + +#elif !defined(DISABLE_FLOAT4_TYPES) + +#define DispatchOnTensorType(tensor_type, function, ...) 
\ + switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_STRING: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT8: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT32: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT64: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + function(__VA_ARGS__); \ + break; \ + default: \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ + } + +#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) 
\ + switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_STRING: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT8: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT32: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT64: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT4: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ + retval = function(__VA_ARGS__); \ + break; \ + default: \ + ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ + } + +#elif !defined(DISABLE_FLOAT8_TYPES) #define DispatchOnTensorType(tensor_type, function, ...) \ switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \ diff --git a/include/onnxruntime/core/framework/float4.h b/include/onnxruntime/core/framework/float4.h new file mode 100644 index 0000000000000..3662556f53398 --- /dev/null +++ b/include/onnxruntime/core/framework/float4.h @@ -0,0 +1,297 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// IMPORTANT NOTE: Users of this file MUST include "cuda.h" before including this header +// if they would like to leverage the CUDA implementation for the conversion routines +// in their HOST code (code compiled by MSVC/GCC). +// This is because there is a check on CUDA_VERSION which is a macro defined in cuda.h. +// We can't include cuda.h in this header unconditionally because this header is also +// included in core framework files which are CUDA-agnostic. +// Not including "cuda.h" in GCC/MSVC will fall-back to the CPU conversion routines +// implemented in this file. +// For code compiled by NVCC which includes this header, this file will automatically +// include cuda.h (based on the CUDA_CC macro). 
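+// This header defines Float4E2M1x2, a struct that packs two 4-bit FP4 (E2M1) elements into a single byte.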
+ +#pragma once + +#if !defined(DISABLE_FLOAT4_TYPES) + +#if defined(__CUDACC__) +// Needed for CUDA_VERSION check below +#include +#endif + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 + +#if defined(_MSC_VER) +#pragma warning(push) +// 'fp4_interpretation' : unreferenced parameter +#pragma warning(disable : 4100) +#endif + +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#endif + +#include +#include +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime { + +#if defined(__CUDACC__) +#define ORT_HOST_DEVICE __host__ __device__ +#else +#define ORT_HOST_DEVICE +#endif + +struct Float4E2M1x2 { + uint8_t val_{0}; + using UnpackedType = float; + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 + using PackedCudaType = __nv_fp4x2_e2m1; + using PackedCudaStorageType = __nv_fp4x2_storage_t; +#endif + + private: + ORT_HOST_DEVICE UnpackedType Fp4ToFloatConversionCpuHelper(uint8_t fp4x2, size_t shift) const { + assert(shift == 0 || shift == 4); + + constexpr uint8_t sign_bitmask = 0x08; + constexpr uint8_t exponent_bitmask = 0x06; + constexpr uint8_t mantissa_bitmask = 0x01; + + uint8_t bits_shifted = (fp4x2 >> shift); + + float sign = 1.f; + if (bits_shifted & sign_bitmask) { + sign = -1.f; + } + + int exponent = static_cast((bits_shifted & exponent_bitmask) >> 1); + float mantissa = static_cast(bits_shifted & mantissa_bitmask); + + return (exponent == 0) ? (sign * (mantissa / 2.f)) : (sign * (1.f + mantissa / 2.f) * static_cast(1 << (exponent - 1))); + } + + ORT_HOST_DEVICE uint8_t FloatToFp4ConversionCpuHelper(float f, size_t shift) const { + assert(shift == 0 || shift == 4); + + constexpr uint32_t sign_bitmask = 0x80000000; + constexpr uint32_t exponent_bitmask = 0x7F800000; + constexpr uint32_t mantissa_bitmask = 0x007FFFFF; + constexpr uint32_t zero = 0x00000000; + + uint8_t res = 0; + + uint32_t float_bits = 0; + std::memcpy(&float_bits, &f, sizeof(f)); + + // NaN always maps to +6 (irrespective of sign) + // https://github.com/onnx/onnx/blob/main/docs/docsgen/source/technical/float4.md + if (((float_bits & exponent_bitmask) == exponent_bitmask) && (float_bits & mantissa_bitmask)) { + return (0x07 << shift); + } + + if (float_bits & sign_bitmask) { + res = 0x08; + } + + // Infinity is sign preserving - magnitude is 6 + if (((float_bits & exponent_bitmask) == exponent_bitmask) && ((float_bits & mantissa_bitmask) == zero)) { + return ((res | 0x07) << shift); + } + + float f_abs = std::abs(f); + if (f_abs > 0.25 && f_abs < 0.75) { + res |= 0x01; + } else if (f_abs >= 0.75 && f_abs <= 1.25) { + res |= 0x02; + } else if (f_abs > 1.25 && f_abs < 1.75) { + res |= 0x03; + } else if (f_abs >= 1.75 && f_abs <= 2.5) { + res |= 0x04; + } else if (f_abs > 2.5 && f_abs < 3.5) { + res |= 0x05; + } else if (f_abs >= 3.5 && f_abs <= 5.0) { + res |= 0x06; + } else if (f_abs > 5.0) { + res |= 0x07; + } + + return res << shift; + } + + public: + Float4E2M1x2() = default; + + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float4E2M1x2(unsigned char bits, FromBitsT) : val_(bits) {} + + inline explicit ORT_HOST_DEVICE Float4E2M1x2(UnpackedType f1, UnpackedType f2) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 + float2 temp; + temp.x = f1; + temp.y = f2; + + // Converts input vector of two single precision numbers packed in float2 x + // into a vector of two values of fp4 type of the requested kind using specified + // rounding mode and saturating the out-of-range values. 
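+ // Each E2M1 element carries 1 sign bit, 2 exponent bits and 1 mantissa bit, so the
+ // representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6; out-of-range inputs
+ // saturate to +/-6 in both the CUDA path below and the CPU fallback.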
+ val_ = __nv_cvt_float2_to_fp4x2(temp, __NV_E2M1, cudaRoundNearest); +#else + val_ = (FloatToFp4ConversionCpuHelper(f1, 0) | FloatToFp4ConversionCpuHelper(f2, 4)); +#endif + } + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 + inline explicit ORT_HOST_DEVICE Float4E2M1x2(float2 f2) { + val_ = __nv_cvt_float2_to_fp4x2(f2, __NV_E2M1, cudaRoundNearest); + } + + inline explicit ORT_HOST_DEVICE Float4E2M1x2(const __nv_fp4x2_e2m1& value) { + val_ = *reinterpret_cast(&value); + } + + inline explicit ORT_HOST_DEVICE operator __nv_fp4x2_e2m1() const { + return *reinterpret_cast(&val_); + } + + inline ORT_HOST_DEVICE float2 ToCudaFloat2() const { + return __half22float2(__nv_cvt_fp4x2_to_halfraw2(static_cast(val_), __NV_E2M1)); + } +#endif + + inline ORT_HOST_DEVICE std::pair ToFloat2() const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 + float2 temp = ToCudaFloat2(); + return std::make_pair(temp.x, temp.y); +#else + return std::make_pair(Fp4ToFloatConversionCpuHelper(val_, 0), Fp4ToFloatConversionCpuHelper(val_, 4)); +#endif + } + + inline ORT_HOST_DEVICE uint8_t ToBits() const { + return val_; + } + + static size_t CalcNumFloat4Pairs(size_t num_float4_elems) { + return (num_float4_elems + 1) / 2; + } + + static void UnpackFloat4E2M1ToFloat(const Float4E2M1x2* fp4x2_arr, + UnpackedType* flt_arr, size_t size) { + auto src = fp4x2_arr; + auto dst = flt_arr; + + size_t dst_i = 0; + + for (; dst_i < size - 1; dst_i += 2) { + auto src_i = dst_i >> 1; + auto flt_pair = src[src_i].ToFloat2(); + dst[dst_i] = flt_pair.first; + dst[dst_i + 1] = flt_pair.second; + } + + if (dst_i < size) { + auto src_i = dst_i >> 1; + dst[dst_i] = fp4x2_arr[src_i].ToFloat2().first; + } + } + + static void PackFloatToFloat4E2M1(const UnpackedType* flt_arr, + Float4E2M1x2* fp4x2_arr, size_t size) { + auto src = flt_arr; + auto dst = fp4x2_arr; + + size_t src_i = 0; + + for (; src_i < size - 1; src_i += 2) { + new (dst) Float4E2M1x2(src[src_i], src[src_i + 1]); + ++dst; + } + + if (src_i < size) { + new (dst) Float4E2M1x2(src[src_i], 0); + } + } + + static inline std::pair GetTensorElemIndices(size_t index) { + return {index >> 1, index & 0x1}; + } + + inline UnpackedType GetElem(size_t index) const { + assert(index <= 1); + auto pair = ToFloat2(); + if (index == 0) { + return static_cast(pair.first); + } + + return static_cast(pair.second); + } +}; + +inline ORT_HOST_DEVICE bool operator==(const Float4E2M1x2& left, const Float4E2M1x2& right) { return left.val_ == right.val_; } +inline ORT_HOST_DEVICE bool operator!=(const Float4E2M1x2& left, const Float4E2M1x2& right) { return left.val_ != right.val_; } + +static_assert(sizeof(Float4E2M1x2) == sizeof(uint8_t)); +} // namespace onnxruntime + +namespace std { +// TODO (hasesh): Does numeric_limits make sense for packed types ? 
+// For now, produce limits of each element in a packed format, refine +// this based on usage later +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float4E2M1x2 lowest() { + return onnxruntime::Float4E2M1x2(0xFF, onnxruntime::Float4E2M1x2::FromBits()); // -6.0 + } + + static constexpr onnxruntime::Float4E2M1x2 max() { + return onnxruntime::Float4E2M1x2(0x77, onnxruntime::Float4E2M1x2::FromBits()); // +6.0 + } + + static constexpr onnxruntime::Float4E2M1x2 min() { + return onnxruntime::Float4E2M1x2(0x22, onnxruntime::Float4E2M1x2::FromBits()); // +1.0 + } + + static constexpr onnxruntime::Float4E2M1x2 denorm_min() { + return onnxruntime::Float4E2M1x2(0x11, onnxruntime::Float4E2M1x2::FromBits()); // +0.5 + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = false; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 2; // (1 mantissa bit + 1 implicit bit) + static constexpr int digits10 = 0; // (digits -1) * std::log10(2) rounded down + static constexpr int max_digits10 = 1; // Mantissa bits + static constexpr int radix = 2; + static constexpr int min_exponent = 1; // 2 ^ (1-1) = 1 is the valid normalized value min ceiling we can reach + static constexpr int min_exponent10 = 0; // 10 ^ 0 is the valid normalized value min ceiling we can reach + static constexpr int max_exponent = 3; // 2 ^ (3-1) = 4 is valid normalized value max ceiling we can reach + static constexpr int max_exponent10 = 0; // 10 ^ 0 is the valid normalized value max ceiling we can reach + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; +} // namespace std + +#endif // DISABLE_FLOAT4_TYPES diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index 9e94cc297f782..c6cfd5a9e2c40 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -1,11 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// IMPORTANT NOTE: Users of this file MUST include "cuda.h" before including this header +// if they would like to leverage the CUDA implementation for the conversion routines +// in their HOST code (code compiled by MSVC/GCC). +// This is because there is a check on CUDA_VERSION which is a macro defined in cuda.h. +// We can't include cuda.h in this header unconditionally because this header is also +// included in core framework files which are CUDA-agnostic. +// Not including "cuda.h" in GCC/MSVC will fall-back to the CPU conversion routines +// implemented in this file. +// For code compiled by NVCC which includes this header, this file will automatically +// include cuda.h (based on the CUDA_CC macro). 
+ #pragma once #if !defined(DISABLE_FLOAT8_TYPES) #include "endian.h" + +#if defined(__CUDACC__) +// Needed for CUDA_VERSION check below +#include +#endif + #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 #include "cuda_fp8.h" #endif diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index dd2603d214f63..c7f7f23f70334 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -284,9 +284,9 @@ class Tensor final { /// /// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for - /// sub-byte data types (e.g., int4). + /// sub-byte data types (e.g., int4/float4). /// - /// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements. + /// For element types smaller than 1 byte (e.g., int4/float4), a single storage element stores multiple sub-byte elements. /// Example: Tensor of shape (4,) has 2 storage elements. /// /// For element types >= 1 byte, this function returns the product of the shape. diff --git a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h index e9e28e4864a67..0a18f19b102a6 100644 --- a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h +++ b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h @@ -10,6 +10,7 @@ #include "core/graph/onnx_protobuf.h" #endif +#include "core/framework/float4.h" #include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/framework/int4.h" @@ -98,6 +99,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1; +} +#endif + template <> constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { return ONNX_NAMESPACE::TensorProto_DataType_INT4; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8561de9c8c3b9..fc142abad1c70 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -206,7 +206,9 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision // Int4 types were introduced in ONNX 1.16. See https://onnx.ai/onnx/technical/int4.html ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, // maps to a pair of packed int4 values (size == 1 byte) + // Float4 types were introduced in ONNX 1.18. 
See https://onnx.ai/onnx/technical/float4.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, // maps to a pair of packed float4 values (size == 1 byte) } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h index 586732834f0ad..8847d9d7f046e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h @@ -151,7 +151,7 @@ inline size_t CalcDynamicBlockMemory(const DecoderMaskedMultiHeadAttentionParame // The extra memory needed if we are not using floats for the final logits. size_t logits_sz = 0; - if (sizeof(T) != 4) { + if constexpr (sizeof(T) != 4) { logits_sz = (((total_sequence_length + 3) / 4) * 4 * sizeof(T)); } diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 06aab16f4a44b..e2192b35b5b20 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -635,6 +635,11 @@ ORT_REGISTER_TENSOR_TYPE(Float8E4M3FNUZ); ORT_REGISTER_TENSOR_TYPE(Float8E5M2); ORT_REGISTER_TENSOR_TYPE(Float8E5M2FNUZ); #endif + +#if !defined(DISABLE_FLOAT4_TYPES) +ORT_REGISTER_TENSOR_TYPE(Float4E2M1x2); +#endif + ORT_REGISTER_TENSOR_TYPE(Int4x2); ORT_REGISTER_TENSOR_TYPE(UInt4x2); @@ -812,6 +817,9 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_TENSOR_PROTO(Float8E4M3FNUZ, reg_fn); REGISTER_TENSOR_PROTO(Float8E5M2, reg_fn); REGISTER_TENSOR_PROTO(Float8E5M2FNUZ, reg_fn); +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + REGISTER_TENSOR_PROTO(Float4E2M1x2, reg_fn); #endif REGISTER_TENSOR_PROTO(Int4x2, reg_fn); REGISTER_TENSOR_PROTO(UInt4x2, reg_fn); @@ -987,6 +995,8 @@ const char* DataTypeImpl::ToString(MLDataType type) { return "Float8E5M2"; case TensorProto_DataType_FLOAT8E5M2FNUZ: return "Float8E5M2FNUZ"; + case TensorProto_DataType_FLOAT4E2M1: + return "Float4E2M1"; case TensorProto_DataType_INT4: return "Int4x2"; case TensorProto_DataType_UINT4: @@ -1048,7 +1058,6 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) { return DataTypeImpl::GetTensorType()->AsTensorType(); #if !defined(DISABLE_FLOAT8_TYPES) - case TensorProto_DataType_FLOAT8E4M3FN: return DataTypeImpl::GetTensorType()->AsTensorType(); case TensorProto_DataType_FLOAT8E4M3FNUZ: @@ -1057,7 +1066,10 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) { return DataTypeImpl::GetTensorType()->AsTensorType(); case TensorProto_DataType_FLOAT8E5M2FNUZ: return DataTypeImpl::GetTensorType()->AsTensorType(); - +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + case TensorProto_DataType_FLOAT4E2M1: + return DataTypeImpl::GetTensorType()->AsTensorType(); #endif case TensorProto_DataType_INT4: return DataTypeImpl::GetTensorType()->AsTensorType(); @@ -1209,6 +1221,13 @@ ORT_REGISTER_PRIM_TYPE(Float8E5M2); ORT_REGISTER_PRIM_TYPE(Float8E5M2FNUZ); #endif + +#if !defined(DISABLE_FLOAT4_TYPES) + +ORT_REGISTER_PRIM_SUBBYTE_TYPE(Float4E2M1x2, 2); + +#endif + ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2); ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2); @@ -1307,6 +1326,12 @@ const std::vector& DataTypeImpl::AllTensorTypesIRv10() { return all_tensor_types; } +const std::vector& 
DataTypeImpl::AllTensorTypesIRv11() { + static std::vector all_tensor_types = + GetTensorTypesFromTypeList(); + return all_tensor_types; +} + const std::vector& DataTypeImpl::AllFixedSizeSequenceTensorTypes() { return AllFixedSizeSequenceTensorTypesIRv4(); } diff --git a/onnxruntime/core/framework/element_type_lists.h b/onnxruntime/core/framework/element_type_lists.h index 2478dc27162ac..2d6eb6a6be580 100644 --- a/onnxruntime/core/framework/element_type_lists.h +++ b/onnxruntime/core/framework/element_type_lists.h @@ -12,6 +12,7 @@ #include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/framework/int4.h" +#include "core/framework/float4.h" namespace onnxruntime { @@ -89,6 +90,20 @@ using AllIRv10 = UInt4x2, Int4x2>; +#if !defined(DISABLE_FLOAT4_TYPES) +using AllIRv11 = + boost::mp11::mp_push_back< + AllIRv10, + Float4E2M1x2>; +#else +using AllIRv11 = AllIRv10; +#endif + +// TODO: This needs upgrade to some newer version ,buit it has been +// at this version for a while and it needs changes at the use sites +// where-in the types in the newer IR versions are not supported. +// This may need a sweep across multiple EPs as this is mostly used +// for kernel registration. using All = AllIRv4; #if !defined(DISABLE_FLOAT8_TYPES) @@ -100,6 +115,10 @@ using AllFloat8 = Float8E5M2FNUZ>; #endif +#if !defined(DISABLE_FLOAT4_TYPES) +using AllFloat4 = TypeList; +#endif + using AllIeeeFloat = TypeList< float, diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index d3e435c0341b0..e8eb6f4ce9d02 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -130,13 +130,15 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe for (size_t i = 0; i < node->InputDefs().size(); ++i) { auto* input = node->InputDefs()[i]; - // skip placing on CPU if the data typs is float16 or bfloat16 or float8e4m3fn, float8e4m3fnuz, floate5m2, floate5m2fnuz + // skip placing on CPU if the data typs is float16 or bfloat16 or + // float8e4m3fn, float8e4m3fnuz, floate5m2, floate5m2fnuz or float4e2m1 if (input->Type() == DataTypeUtils::ToType("float16") || input->Type() == DataTypeUtils::ToType("bfloat16") || input->Type() == DataTypeUtils::ToType("float8e4m3fn") || input->Type() == DataTypeUtils::ToType("float8e4m3fnuz") || input->Type() == DataTypeUtils::ToType("float8e5m2") || - input->Type() == DataTypeUtils::ToType("float8e5m2fnuz")) { + input->Type() == DataTypeUtils::ToType("float8e5m2fnuz") || + input->Type() == DataTypeUtils::ToType("float4e2m1")) { place_in_cpu = false; break; } diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 1370580bad4f6..461e82d72dc83 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -84,6 +84,9 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { case TensorType::TensorProto_DataType_UINT4: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; } // maps to a pair of uint4 (size == 1 byte) + case TensorType::TensorProto_DataType_FLOAT4E2M1: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1; + } // maps to a pair of float4 (size == 1 byte) default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } diff --git a/onnxruntime/core/framework/tensor.cc 
b/onnxruntime/core/framework/tensor.cc index 60d768cc59a5d..133f21be97c46 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -30,7 +30,7 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st /// /// Get the number of elements for a Tensor of the given element type and shape size. /// -/// For element types smaller than 1 byte (e.g., int4), a single storage element stores multiple sub-byte elements. +/// For element types smaller than 1 byte (e.g., int4/float4), a single storage element stores multiple sub-byte elements. /// Example: Tensor of shape_size 4 has 2 storage elements. /// /// For element types >= 1 byte, this function returns the product of the shape. diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 61dcd4ea8035f..bef8df51f6d03 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -199,6 +199,9 @@ constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementData case o::TensorProto_DataType_UINT4: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; break; + case o::TensorProto_DataType_FLOAT4E2M1: + type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1; + break; default: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; break; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index ff440b595e499..e686303f3ebeb 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -59,7 +59,7 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) { return t; \ } -#define TO_TENSOR_ORT_TYPE_INT4(TYPE) \ +#define TO_TENSOR_ORT_TYPE_4BIT_TYPE(TYPE) \ template <> \ TensorProto ToTensor(const onnxruntime::TYPE& value) { \ return ToScalarTensor(ToTensorProtoElementType(), static_cast(value.ToBits())); \ @@ -84,8 +84,11 @@ TO_TENSOR_ORT_TYPE(Float8E4M3FNUZ) TO_TENSOR_ORT_TYPE(Float8E5M2) TO_TENSOR_ORT_TYPE(Float8E5M2FNUZ) #endif -TO_TENSOR_ORT_TYPE_INT4(Int4x2) -TO_TENSOR_ORT_TYPE_INT4(UInt4x2) +#if !defined(DISABLE_FLOAT4_TYPES) +TO_TENSOR_ORT_TYPE_4BIT_TYPE(Float4E2M1x2) +#endif +TO_TENSOR_ORT_TYPE_4BIT_TYPE(Int4x2) +TO_TENSOR_ORT_TYPE_4BIT_TYPE(UInt4x2) bool operator==(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l, const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) { @@ -141,28 +144,32 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t reinterpret_cast(p_data)); } -#define DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(INT4_TYPE) \ - template <> \ - Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, \ - /*out*/ INT4_TYPE* p_data) { \ - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ - \ - ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ - \ - size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ - ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); \ - \ - gsl::span src_span = \ - gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); \ - gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); \ - \ - std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ - \ - return Status::OK(); \ - } - -DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2) -DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) +#define 
DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(FOUR_BIT_TYPE, CalcPairFun) \ + template <> \ + Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, \ + /*out*/ FOUR_BIT_TYPE* p_data) { \ + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ + \ + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ + \ + size_t num_packed_pairs = FOUR_BIT_TYPE::CalcPairFun(expected_num_elements); \ + ORT_RETURN_IF_NOT(num_packed_pairs == raw_data_len, "Unexpected number of packed int4 pairs"); \ + \ + gsl::span src_span = \ + gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); \ + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); \ + \ + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ + \ + return Status::OK(); \ + } + +DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2, CalcNumInt4Pairs) +DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2, CalcNumInt4Pairs) + +#if !defined(DISABLE_FLOAT4_TYPES) +DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs) +#endif // Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully. // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr @@ -437,31 +444,35 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, reinterpret_cast(p_data)); } -#define DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(INT4_TYPE) \ - template <> \ - Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ - const std::filesystem::path& tensor_proto_dir, \ - size_t expected_num_elements, /*out*/ INT4_TYPE* p_data) { \ - static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ - \ - ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ - std::vector unpacked_tensor; \ - ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \ - \ - size_t num_packed_pairs = INT4_TYPE::CalcNumInt4Pairs(expected_num_elements); \ - ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \ - \ - gsl::span src_span = \ - gsl::make_span(reinterpret_cast(unpacked_tensor.data()), num_packed_pairs); \ - gsl::span dst_span = gsl::make_span(p_data, expected_num_elements); \ - \ - std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ - \ - return Status::OK(); \ - } - -DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2) -DEFINE_INT4_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2) +#define DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(FOUR_BIT_TYPE, CalcPairFun) \ + template <> \ + Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ + const std::filesystem::path& tensor_proto_dir, \ + size_t expected_num_elements, /*out*/ FOUR_BIT_TYPE* p_data) { \ + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); \ + \ + ORT_RETURN_IF(nullptr == p_data, "nullptr == p_data"); \ + std::vector unpacked_tensor; \ + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(tensor, tensor_proto_dir, unpacked_tensor)); \ + \ + size_t num_packed_pairs = FOUR_BIT_TYPE::CalcPairFun(expected_num_elements); \ + ORT_RETURN_IF_NOT(num_packed_pairs == unpacked_tensor.size(), "Unexpected number of packed int4 pairs"); \ + \ + gsl::span src_span = \ + gsl::make_span(reinterpret_cast(unpacked_tensor.data()), num_packed_pairs); \ + gsl::span dst_span = gsl::make_span(p_data, 
expected_num_elements); \ + \ + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); \ + \ + return Status::OK(); \ + } + +DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2, CalcNumInt4Pairs) +DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2, CalcNumInt4Pairs) + +#if !defined(DISABLE_FLOAT4_TYPES) +DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs) +#endif #define INSTANTIATE_UNPACK_EXTERNAL_TENSOR(type) \ template Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto&, const std::filesystem::path&, \ @@ -843,6 +854,43 @@ DEFINE_INT4_UNPACK_TENSOR_IMPL(Int4x2, TensorProto_DataType_INT4) // UnpackTensor DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4) +#if !defined(DISABLE_FLOAT4_TYPES) + +template <> +Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, + /*out*/ Float4E2M1x2* p_data, size_t expected_num_elems) { + if (nullptr == p_data) { + const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); + return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); + } + if (ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1 != tensor.data_type()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); + } + + size_t expected_float4_pairs = Float4E2M1x2::CalcNumFloat4Pairs(expected_num_elems); + + if (raw_data != nullptr) { + return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); + } + + ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_float4_pairs, + "UnpackTensor: the pre-allocated size does not match the size in proto"); + + constexpr int max_value = std::numeric_limits::max(); + + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { + int v = tensor.int32_data()[i]; + if (v < 0 || v > max_value) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "data overflow"); + } + p_data[i] = Float4E2M1x2(static_cast(v), Float4E2M1x2::FromBits()); + } + + return Status::OK(); +} + +#endif + // UnpackTensor from raw data, external data or the type specific data field. // Uses the model path to construct the full path for loading external data. In case when model_path is empty // it uses current directory. 
@@ -906,6 +954,15 @@ INSTANTIATE_UNPACK_TENSOR(UInt4x2) } \ break; +#if !defined(DISABLE_FLOAT4_TYPES) +#define CASE_PROTO_TRACE_FLOAT4(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!IAllocator::CalcMemSizeForArrayWithAlignment(Y::CalcNumFloat4Pairs(size), sizeof(Y), out)) { \ + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \ + } \ + break; +#endif + template common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, int32_t element_type, size_t* out) { const auto size = narrow(shape.Size()); @@ -932,6 +989,10 @@ common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, in #endif CASE_PROTO_TRACE_INT4(UINT4, UInt4x2); CASE_PROTO_TRACE_INT4(INT4, Int4x2); + +#if !defined(DISABLE_FLOAT4_TYPES) + CASE_PROTO_TRACE_FLOAT4(FLOAT4E2M1, Float4E2M1x2); +#endif default: return common::Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } @@ -1322,6 +1383,11 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa #endif CASE_PROTO(INT4, Int4x2); CASE_PROTO(UINT4, UInt4x2); + +#if !defined(DISABLE_FLOAT4_TYPES) + CASE_PROTO(FLOAT4E2M1, Float4E2M1x2); +#endif + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_STRING: ORT_RETURN_IF_ERROR(UnpackTensor(tensor_proto, raw_data, raw_data_len, static_cast(preallocated), @@ -1402,6 +1468,11 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { #endif CASE_TYPE(UINT4) CASE_TYPE(INT4) + +#if !defined(DISABLE_FLOAT4_TYPES) + CASE_TYPE(FLOAT4E2M1) +#endif + default: return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -1957,11 +2028,11 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T break; \ } -#define CASE_UNPACK_INT4(TYPE, ELEMENT_TYPE, DATA_SIZE) \ +#define CASE_UNPACK_4BIT_TYPE(TYPE, ELEMENT_TYPE, DATA_SIZE, CALC_PAIR_FUN) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ size_t element_count = static_cast(tensor_shape.Size()); \ - size_t packed_element_count = ELEMENT_TYPE::CalcNumInt4Pairs(element_count); \ + size_t packed_element_count = ELEMENT_TYPE::CALC_PAIR_FUN(element_count); \ unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ return onnxruntime::utils::UnpackTensor(initializer, \ initializer.has_raw_data() ? 
initializer.raw_data().data() : nullptr, \ @@ -2004,8 +2075,13 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, CASE_UNPACK(FLOAT8E5M2, onnxruntime::Float8E5M2, int32_data_size); CASE_UNPACK(FLOAT8E5M2FNUZ, onnxruntime::Float8E5M2FNUZ, int32_data_size); #endif - CASE_UNPACK_INT4(INT4, Int4x2, int32_data_size); - CASE_UNPACK_INT4(UINT4, UInt4x2, int32_data_size); + CASE_UNPACK_4BIT_TYPE(INT4, Int4x2, int32_data_size, CalcNumInt4Pairs); + CASE_UNPACK_4BIT_TYPE(UINT4, UInt4x2, int32_data_size, CalcNumInt4Pairs); + +#if !defined(DISABLE_FLOAT4_TYPES) + CASE_UNPACK_4BIT_TYPE(FLOAT4E2M1, Float4E2M1x2, int32_data_size, CalcNumFloat4Pairs); +#endif + default: break; } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 6b5c404e26b7f..3a23093a5445b 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -229,6 +229,13 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; } +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1; +} +#endif + int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); #ifdef ENABLE_TRAINING diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 9123c0bd76ec7..5fe700e1711a4 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 0b56cac1038e4..bef0559d967d0 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -13,6 +13,7 @@ #include "core/common/status.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/float4.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fast_divmod.h" @@ -99,6 +100,15 @@ class ToCudaType { #endif +#if defined(ENABLE_FP4) && !defined(DISABLE_FLOAT4_TYPES) +// ENABLE_FP4 is only set if CUDA SDK version is >= 12.8 +template <> +class ToCudaType { + public: + typedef Float4E2M1x2::PackedCudaType MappedType; +}; +#endif + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index e036c7764d041..7c1b6e8f393cf 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1392,22 +1392,22 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); // Opset 19 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 
BFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, bool, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, float, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, double, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, MLFloat16, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, BFloat16, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, int64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, bool, Cast); #if !defined(DISABLE_FLOAT8_TYPES) -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, Cast); #endif class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, DequantizeLinear); @@ -1491,6 +1491,27 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, RotaryEmbedding); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, bool, Cast); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast); +#endif + +#if !defined(DISABLE_FLOAT4_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float4E2M1x2, Cast); +#endif #endif @@ -2384,23 +2405,23 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - // Opset 19 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // Opset 19-20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif BuildKernelCreateInfo, @@ -2484,6 +2505,26 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + BuildKernelCreateInfo, +#endif #endif }; diff --git a/onnxruntime/core/providers/cuda/cuda_type_conversion.h b/onnxruntime/core/providers/cuda/cuda_type_conversion.h index f118bc9c69bcc..5d6f4f845a9f4 100644 --- 
a/onnxruntime/core/providers/cuda/cuda_type_conversion.h +++ b/onnxruntime/core/providers/cuda/cuda_type_conversion.h @@ -8,11 +8,15 @@ #if defined(ENABLE_FP8) && !defined(DISABLE_FLOAT8_TYPES) #include #endif +#if defined(ENABLE_FP4) && !defined(DISABLE_FLOAT4_TYPES) +#include +#endif #include #include #include "core/framework/int4.h" #include "core/framework/float8.h" #include "core/framework/float16.h" +#include "core/framework/float4.h" namespace onnxruntime { namespace cuda { @@ -88,5 +92,12 @@ struct OrtToCudaType { }; #endif +#if defined(ENABLE_FP4) && !defined(DISABLE_FLOAT4_TYPES) +template <> +struct OrtToCudaType { + using type = Float4E2M1x2::PackedCudaType; +}; +#endif + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cc b/onnxruntime/core/providers/cuda/tensor/cast_op.cc index 821695bbbd42f..8f5c9202c1dba 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cc @@ -30,6 +30,10 @@ const std::vector& CastOpTypeConstraints() { #if !defined(DISABLE_FLOAT8_TYPES) , DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType() +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + , + DataTypeImpl::GetTensorType() #endif }; return types; @@ -66,10 +70,30 @@ const std::vector& CastOpTypeConstraints() { .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", CastOpTypeConstraints()), \ Cast); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Cast, \ + kOnnxDomain, \ + 19, 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", CastOpTypeConstraints()), \ + Cast); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Cast, \ + kOnnxDomain, \ + 21, 22, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", CastOpTypeConstraints()), \ + Cast); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Cast, \ kOnnxDomain, \ - 19, \ + 23, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -116,6 +140,21 @@ const std::vector& CastOpTypeConstraints() { #endif +#if !defined(DISABLE_FLOAT4_TYPES) + +#define CASE_BYTE_PACKED(TP_TYPE, SrcT, DstT) \ + case TP_TYPE: \ + if (count > 0) { \ + return cast_helper_impl::CudaCastPairwise( \ + Stream(context), \ + X->Data(), \ + Y->MutableData(), \ + count); \ + } \ + break; + +#endif + template Status Cast::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaSrcT; @@ -149,87 +188,85 @@ Status Cast::ComputeInternal(OpKernelContext* context) const { case TensorProto_DataType_UNDEFINED: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cast op must have 'to' argument of type DataType"); default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected 'to' argument value: ", to_); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Unimplemented 'to' argument value: ", to_); } return Status::OK(); } +template <> +Status Cast::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaSrcT; + const Tensor* X = context->Input(0); + const TensorShape& shape = X->Shape(); + Tensor* Y = context->Output(0, shape); + const auto* x_data = reinterpret_cast(X->Data()); + size_t count = shape.Size(); + switch (to_) { + CASE(TensorProto_DataType_FLOAT16, MLFloat16) + CASE(TensorProto_DataType_BFLOAT16, BFloat16) + CASE(TensorProto_DataType_FLOAT, float) + 
CASE(TensorProto_DataType_DOUBLE, double) + CASE(TensorProto_DataType_INT8, int8_t) + CASE(TensorProto_DataType_INT16, int16_t) + CASE(TensorProto_DataType_INT32, int32_t) + CASE(TensorProto_DataType_INT64, int64_t) + CASE(TensorProto_DataType_UINT8, uint8_t) + CASE(TensorProto_DataType_UINT16, uint16_t) + CASE(TensorProto_DataType_UINT32, uint32_t) + CASE(TensorProto_DataType_UINT64, uint64_t) + CASE(TensorProto_DataType_BOOL, bool) #if !defined(DISABLE_FLOAT8_TYPES) - -#define COMPUTE_INTERNAL_FL16_32(FLOAT_TYPE) \ - template <> \ - Status Cast::ComputeInternal(OpKernelContext* context) const { \ - typedef typename ToCudaType::MappedType CudaSrcT; \ - const Tensor* X = context->Input(0); \ - const TensorShape& shape = X->Shape(); \ - Tensor* Y = context->Output(0, shape); \ - const auto* x_data = reinterpret_cast(X->Data()); \ - size_t count = shape.Size(); \ - switch (to_) { \ - CASE(TensorProto_DataType_FLOAT16, MLFloat16) \ - CASE(TensorProto_DataType_BFLOAT16, BFloat16) \ - CASE(TensorProto_DataType_FLOAT, float) \ - CASE(TensorProto_DataType_DOUBLE, double) \ - CASE(TensorProto_DataType_INT8, int8_t) \ - CASE(TensorProto_DataType_INT16, int16_t) \ - CASE(TensorProto_DataType_INT32, int32_t) \ - CASE(TensorProto_DataType_INT64, int64_t) \ - CASE(TensorProto_DataType_UINT8, uint8_t) \ - CASE(TensorProto_DataType_UINT16, uint16_t) \ - CASE(TensorProto_DataType_UINT32, uint32_t) \ - CASE(TensorProto_DataType_UINT64, uint64_t) \ - CASE(TensorProto_DataType_BOOL, bool) \ - CASE_SAT(TensorProto_DataType_FLOAT8E4M3FN, Float8E4M3FN) \ - CASE_SAT(TensorProto_DataType_FLOAT8E5M2, Float8E5M2) \ - case TensorProto_DataType_STRING: \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Casting to and from strings is not supported yet."); \ - case TensorProto_DataType_UNDEFINED: \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cast op must have 'to' argument of type DataType"); \ - default: \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected 'to' argument value: ", to_); \ - } \ - return Status::OK(); \ - } - -#else - -#define COMPUTE_INTERNAL_FL16_32(FLOAT_TYPE) \ - template <> \ - Status Cast::ComputeInternal(OpKernelContext* context) const { \ - typedef typename ToCudaType::MappedType CudaSrcT; \ - const Tensor* X = context->Input(0); \ - const TensorShape& shape = X->Shape(); \ - Tensor* Y = context->Output(0, shape); \ - const auto* x_data = reinterpret_cast(X->Data()); \ - size_t count = shape.Size(); \ - switch (to_) { \ - CASE(TensorProto_DataType_FLOAT16, MLFloat16) \ - CASE(TensorProto_DataType_BFLOAT16, BFloat16) \ - CASE(TensorProto_DataType_FLOAT, float) \ - CASE(TensorProto_DataType_DOUBLE, double) \ - CASE(TensorProto_DataType_INT8, int8_t) \ - CASE(TensorProto_DataType_INT16, int16_t) \ - CASE(TensorProto_DataType_INT32, int32_t) \ - CASE(TensorProto_DataType_INT64, int64_t) \ - CASE(TensorProto_DataType_UINT8, uint8_t) \ - CASE(TensorProto_DataType_UINT16, uint16_t) \ - CASE(TensorProto_DataType_UINT32, uint32_t) \ - CASE(TensorProto_DataType_UINT64, uint64_t) \ - CASE(TensorProto_DataType_BOOL, bool) \ - case TensorProto_DataType_STRING: \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Casting to and from strings is not supported yet."); \ - case TensorProto_DataType_UNDEFINED: \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cast op must have 'to' argument of type DataType"); \ - default: \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected 'to' argument value: ", to_); \ - } \ - return Status::OK(); \ + CASE_SAT(TensorProto_DataType_FLOAT8E4M3FN, Float8E4M3FN) + 
CASE_SAT(TensorProto_DataType_FLOAT8E5M2, Float8E5M2) +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + CASE_BYTE_PACKED(TensorProto_DataType_FLOAT4E2M1, float, Float4E2M1x2) +#endif + case TensorProto_DataType_STRING: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Casting to and from strings is not supported yet."); + case TensorProto_DataType_UNDEFINED: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cast op must have 'to' argument of type DataType"); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Unimplemented 'to' argument value: ", to_); } + return Status::OK(); +} +template <> +Status Cast::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaSrcT; + const Tensor* X = context->Input(0); + const TensorShape& shape = X->Shape(); + Tensor* Y = context->Output(0, shape); + const auto* x_data = reinterpret_cast(X->Data()); + size_t count = shape.Size(); + switch (to_) { + CASE(TensorProto_DataType_FLOAT16, MLFloat16) + CASE(TensorProto_DataType_BFLOAT16, BFloat16) + CASE(TensorProto_DataType_FLOAT, float) + CASE(TensorProto_DataType_DOUBLE, double) + CASE(TensorProto_DataType_INT8, int8_t) + CASE(TensorProto_DataType_INT16, int16_t) + CASE(TensorProto_DataType_INT32, int32_t) + CASE(TensorProto_DataType_INT64, int64_t) + CASE(TensorProto_DataType_UINT8, uint8_t) + CASE(TensorProto_DataType_UINT16, uint16_t) + CASE(TensorProto_DataType_UINT32, uint32_t) + CASE(TensorProto_DataType_UINT64, uint64_t) + CASE(TensorProto_DataType_BOOL, bool) +#if !defined(DISABLE_FLOAT8_TYPES) + CASE_SAT(TensorProto_DataType_FLOAT8E4M3FN, Float8E4M3FN) + CASE_SAT(TensorProto_DataType_FLOAT8E5M2, Float8E5M2) #endif - -COMPUTE_INTERNAL_FL16_32(float) -COMPUTE_INTERNAL_FL16_32(MLFloat16) + case TensorProto_DataType_STRING: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Casting to and from strings is not supported yet."); + case TensorProto_DataType_UNDEFINED: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cast op must have 'to' argument of type DataType"); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Unimplemented 'to' argument value: ", to_); + } + return Status::OK(); +} // TODO: enable BFLOAT16 in another PR. 
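For reference, the byte-packed FP4 layout that the CASE_BYTE_PACKED dispatch above and the CudaCastPairwise kernels later in this patch operate on stores two E2M1 elements per byte, low nibble first, with an odd trailing element occupying the low nibble of the last byte. The sketch below is illustrative only and not part of the patch; the helper names are invented for the example, and the magnitude table is assumed from the ONNX FP4E2M1 definition.

#include <cstddef>
#include <cstdint>

// Decode one E2M1 nibble: 1 sign bit, 2 exponent bits, 1 mantissa bit.
inline float Fp4E2M1ToFloat(uint8_t nibble) {
  // Magnitudes for the 3 value bits (assumed from the ONNX FP4 spec).
  static constexpr float kMagnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  const float v = kMagnitude[nibble & 0x7];
  return (nibble & 0x8) ? -v : v;
}

// Expand `count` packed FP4 elements into floats (two elements per byte,
// low nibble first); an odd count uses only the low nibble of the last byte.
inline void UnpackFp4ToFloat(const uint8_t* packed, float* out, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    const uint8_t byte = packed[i >> 1];
    const uint8_t nibble = (i & 1) ? static_cast<uint8_t>(byte >> 4)
                                   : static_cast<uint8_t>(byte & 0x0F);
    out[i] = Fp4E2M1ToFloat(nibble);
  }
}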
/* @@ -240,47 +277,48 @@ COMPUTE_INTERNAL_FL16_32(BFloat16) #if !defined(DISABLE_FLOAT8_TYPES) -#define COMPUTE_INTERNAL_FL8(FLOAT_TYPE) \ - template <> \ - Status Cast::ComputeInternal(OpKernelContext* context) const { \ - typedef typename ToCudaType::MappedType CudaSrcT; \ - const Tensor* X = context->Input(0); \ - const TensorShape& shape = X->Shape(); \ - Tensor* Y = context->Output(0, shape); \ - const auto* x_data = reinterpret_cast(X->Data()); \ - size_t count = shape.Size(); \ - switch (to_) { \ - case TensorProto_DataType_FLOAT16: \ - if (count > 0) { \ - Impl_Cast( \ - Stream(context), \ - x_data, \ - reinterpret_cast(Y->MutableData()), \ - count); \ - } \ - break; \ - case TensorProto_DataType_BFLOAT16: \ - if (count > 0) { \ - Impl_Cast( \ - Stream(context), \ - x_data, \ - reinterpret_cast(Y->MutableData()), \ - count); \ - } \ - break; \ - case TensorProto_DataType_FLOAT: \ - if (count > 0) { \ - Impl_Cast( \ - Stream(context), \ - x_data, \ - reinterpret_cast(Y->MutableData()), \ - count); \ - } \ - break; \ - default: \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected 'to' argument value: ", to_); \ - } \ - return Status::OK(); \ +#define COMPUTE_INTERNAL_FL8(FLOAT8_TYPE) \ + template <> \ + Status Cast::ComputeInternal(OpKernelContext* context) const { \ + typedef typename ToCudaType::MappedType CudaSrcT; \ + const Tensor* X = context->Input(0); \ + const TensorShape& shape = X->Shape(); \ + Tensor* Y = context->Output(0, shape); \ + const auto* x_data = reinterpret_cast(X->Data()); \ + size_t count = shape.Size(); \ + switch (to_) { \ + case TensorProto_DataType_FLOAT16: \ + if (count > 0) { \ + Impl_Cast( \ + Stream(context), \ + x_data, \ + reinterpret_cast(Y->MutableData()), \ + count); \ + } \ + break; \ + case TensorProto_DataType_BFLOAT16: \ + if (count > 0) { \ + Impl_Cast( \ + Stream(context), \ + x_data, \ + reinterpret_cast(Y->MutableData()), \ + count); \ + } \ + break; \ + case TensorProto_DataType_FLOAT: \ + if (count > 0) { \ + Impl_Cast( \ + Stream(context), \ + x_data, \ + reinterpret_cast(Y->MutableData()), \ + count); \ + } \ + break; \ + default: \ + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, \ + "Unimplemented 'to' argument value: ", to_); \ + } \ + return Status::OK(); \ } COMPUTE_INTERNAL_FL8(Float8E4M3FN) @@ -288,6 +326,29 @@ COMPUTE_INTERNAL_FL8(Float8E5M2) #endif +#if !defined(DISABLE_FLOAT4_TYPES) + +template <> +Status Cast::ComputeInternal(OpKernelContext* context) const { + const Tensor* X = context->Input(0); + const TensorShape& shape = X->Shape(); + Tensor* Y = context->Output(0, shape); + size_t count = shape.Size(); + + switch (to_) { + CASE_BYTE_PACKED(TensorProto_DataType_FLOAT, Float4E2M1x2, float); + case TensorProto_DataType_STRING: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Casting to and from strings is not supported yet."); + case TensorProto_DataType_UNDEFINED: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cast op must have 'to' argument of type DataType"); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Unimplemented 'to' argument value: ", to_); + } + return Status::OK(); +} + +#endif + #define SPECIALIZE_IMPL(T) \ REGISTER_KERNEL_TYPED(T) \ template Status Cast::ComputeInternal(OpKernelContext* context) const; @@ -306,11 +367,21 @@ SPECIALIZE_IMPL(uint64_t) SPECIALIZE_IMPL(bool) SPECIALIZE_IMPL(BFloat16) -#define REGISTER_KERNEL_TYPED_19(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ +#define REGISTER_KERNEL_TYPED_19_TO_22(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Cast, \ + kOnnxDomain, 
\ + 19, 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", CastOpTypeConstraints()), \ + Cast); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Cast, \ kOnnxDomain, \ - 19, \ + 21, 22, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -318,15 +389,33 @@ SPECIALIZE_IMPL(BFloat16) .TypeConstraint("T2", CastOpTypeConstraints()), \ Cast); +#define REGISTER_KERNEL_TYPED_23(T, OutputTypeConstraints) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Cast, \ + kOnnxDomain, \ + 23, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", OutputTypeConstraints), \ + Cast); + #if !defined(DISABLE_FLOAT8_TYPES) -#define SPECIALIZE_IMPL_19(T) \ - REGISTER_KERNEL_TYPED_19(T) \ +#define SPECIALIZE_IMPL_19_TO_23(T) \ + REGISTER_KERNEL_TYPED_19_TO_22(T) \ + REGISTER_KERNEL_TYPED_23(T, CastOpTypeConstraints()) \ template Status Cast::ComputeInternal(OpKernelContext* context) const; -SPECIALIZE_IMPL_19(Float8E4M3FN) -SPECIALIZE_IMPL_19(Float8E5M2) +SPECIALIZE_IMPL_19_TO_23(Float8E4M3FN) +SPECIALIZE_IMPL_19_TO_23(Float8E5M2) + +#endif +#if !defined(DISABLE_FLOAT4_TYPES) +REGISTER_KERNEL_TYPED_23(Float4E2M1x2, {DataTypeImpl::GetTensorType()}) +template Status Cast::ComputeInternal(OpKernelContext* context) const; #endif } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cu b/onnxruntime/core/providers/cuda/tensor/cast_op.cu index c98eabecdabab..a8cd6caaa5d5f 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cu +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cu @@ -9,9 +9,13 @@ #include "cuda_fp8.h" #endif +#include "core/framework/float4.h" + +#include "cast_op.h" + namespace onnxruntime { namespace cuda { - +namespace cast_helper_impl { template struct CastStd; @@ -143,6 +147,78 @@ struct CastSat { #endif // DISABLE_FLOAT8_TYPES +#if !defined(DISABLE_FLOAT4_TYPES) + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 + +template <> +struct CastStd { + __device__ __forceinline__ float2 operator()(Float4E2M1x2 v) const { + return v.ToCudaFloat2(); + } +}; + +template <> +struct CastStd { + __device__ __forceinline__ Float4E2M1x2 operator()(float2 v) const { + return Float4E2M1x2(v); + } +}; + +template <> +struct CastStd { + __device__ __forceinline__ float operator()(Float4E2M1x2 v) const { + return v.ToCudaFloat2().x; + } +}; + +template <> +struct CastStd { + __device__ __forceinline__ Float4E2M1x2 operator()(float v) const { + return Float4E2M1x2(v, 0); + } +}; + +#else +template <> +struct CastStd { + __device__ __forceinline__ float2 operator()(Float4E2M1x2 v) const { + auto float_pair = v.ToFloat2(); + + float2 res; + res.x = float_pair.first; + res.y = float_pair.second; + + return res; + } +}; + +template <> +struct CastStd { + __device__ __forceinline__ Float4E2M1x2 operator()(float2 v) const { + return Float4E2M1x2(v.x, v.y); + } +}; + +template <> +struct CastStd { + __device__ __forceinline__ float operator()(Float4E2M1x2 v) const { + auto float_pair = v.ToFloat2(); + return float_pair.first; + } +}; + +template <> +struct CastStd { + __device__ __forceinline__ Float4E2M1x2 operator()(float v) const { + return Float4E2M1x2(v, 0); + } +}; + +#endif + +#endif // DISABLE_FLOAT4_TYPES + template __global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastStd cast) { CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + 
threadIdx.x; @@ -157,19 +233,123 @@ __global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastS } template -Status CudaCastStd(cudaStream_t stream, const InT* input, OutT* output, size_t num_of_element) { - if (num_of_element <= 0) +Status CudaCastStd(cudaStream_t stream, const InT* input, OutT* output, size_t num_of_elements) { + if (num_of_elements <= 0) return Status::OK(); - int blocksPerGrid = static_cast(CeilDiv(num_of_element, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + int blocksPerGrid = static_cast(CeilDiv(num_of_elements, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); CastKernelStd<<>>( input, output, - static_cast(num_of_element), + static_cast(num_of_elements), CastStd()); return Status::OK(); } +#if !defined(DISABLE_FLOAT4_TYPES) + +template +__global__ void CudaCastPairwiseKernel(const InPairType* input, OutPairType* output, + CUDA_LONG pair_count, + CastStd pair_caster, + CastStd singleton_caster) { + CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < pair_count) { + output[id] = pair_caster(input[id]); + id += NumThreadsPerBlock; + } + if constexpr (is_odd) { + // If odd, one thread picks up the singleton element + if (id == pair_count) { + *reinterpret_cast(&output[id]) = singleton_caster(*reinterpret_cast(&input[id])); + } + } + } +} + +template +Status CudaCastPairwise(cudaStream_t stream, const InT* input, OutT* output, size_t num_of_elements) { + // There is no generic implementation - specialized implementation for the packed type(s) follow + return Status::OK(); +} + +template <> +Status CudaCastPairwise(cudaStream_t stream, const Float4E2M1x2* input, float* output, size_t num_of_elements) { + if (num_of_elements <= 0) + return Status::OK(); + + bool is_odd = (num_of_elements & 0x01) != 0; + + int pair_count = static_cast(num_of_elements / 2); + + int blocksPerGrid = static_cast(CeilDiv(pair_count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + + if (pair_count == 0) { + blocksPerGrid = 1; + } + + if (is_odd) { + CudaCastPairwiseKernel + <<>>( + input, reinterpret_cast(output), pair_count, + CastStd(), + CastStd()); + } else { + CudaCastPairwiseKernel + <<>>( + input, reinterpret_cast(output), pair_count, + CastStd(), + CastStd()); + } + + return Status::OK(); +} + +template <> +Status CudaCastPairwise(cudaStream_t stream, const float* input, Float4E2M1x2* output, size_t num_of_elements) { + if (num_of_elements <= 0) + return Status::OK(); + + bool is_odd = (num_of_elements & 0x01) != 0; + + int pair_count = static_cast(num_of_elements / 2); + + int blocksPerGrid = static_cast(CeilDiv(pair_count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + + if (pair_count == 0) { + blocksPerGrid = 1; + } + + if (is_odd) { + CudaCastPairwiseKernel + <<>>( + reinterpret_cast(input), output, pair_count, + CastStd(), + CastStd()); + } else { + CudaCastPairwiseKernel + <<>>( + reinterpret_cast(input), output, pair_count, + CastStd(), + CastStd()); + } + + return Status::OK(); +} + +template Status CudaCastPairwise(cudaStream_t stream, const Float4E2M1x2* input, float* output, size_t num_of_element); +template Status CudaCastPairwise(cudaStream_t stream, const float* input, Float4E2M1x2* output, size_t num_of_element); + +#endif + #if !defined(DISABLE_FLOAT8_TYPES) template @@ -214,5 +394,6 @@ template Status CudaCastSat(cudaStream_t stream, const half* i #endif +} // namespace 
cast_helper_impl } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.h b/onnxruntime/core/providers/cuda/tensor/cast_op.h index f4304922b8f37..943874f29ccba 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.h +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.h @@ -36,5 +36,10 @@ class Cast final : public CudaKernel { bool saturate_; }; +namespace cast_helper_impl { +template +Status CudaCastPairwise(cudaStream_t stream, const InT* input, OutT* output, size_t num_of_elements); +} // namespace cast_helper_impl + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index a7fd83f10fe18..cb14b1cdbb645 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -30,6 +30,7 @@ #include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/framework/int4.h" +#include "core/framework/float4.h" #include "core/framework/tensor_shape.h" #include "core/providers/providers.h" #include "core/common/path_string.h" @@ -77,6 +78,7 @@ enum TensorProto_DataType : int { TensorProto_DataType_FLOAT8E5M2FNUZ = 20, TensorProto_DataType_UINT4 = 21, TensorProto_DataType_INT4 = 22, + TensorProto_DataType_FLOAT4E2M1 = 23, }; enum TensorProto_DataLocation : int { @@ -95,7 +97,8 @@ enum Version : int { IR_VERSION_2020_5_8 = 7, IR_VERSION_2021_7_31 = 8, IR_VERSION_2023_5_5 = 9, - IR_VERSION = 10 + IR_VERSION_2024_3_25 = 10, + IR_VERSION = 11 }; enum OperatorStatus : int { @@ -394,6 +397,12 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { template <> constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ; } #endif + +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1; } +#endif + template <> constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d690cf31072d2..0e5df0026d2c0 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -173,6 +173,11 @@ MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataType template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_UInt4x2(); } +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Float4E2M1x2(); } +#endif + template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_string(); } MLDataType DataTypeImpl::GetTensorTypeFromOnnxType(int onnx_type) { return Provider_GetHost()->DataTypeImpl__GetTensorTypeFromOnnxType(onnx_type); } @@ -218,6 +223,11 @@ MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->Da template <> MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_UInt4x2(); } +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Float4E2M1x2(); } 
+#endif + #if !defined(DISABLE_SPARSE_TENSORS) template <> MLDataType DataTypeImpl::GetSparseTensorType() { return Provider_GetHost()->DataTypeImpl__GetSparseTensorType_bool(); } diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 9a0bcb53c9ad7..7e56886b1f558 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -760,6 +760,9 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetType_Float8E4M3FNUZ() = 0; virtual MLDataType DataTypeImpl__GetType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetType_Float8E5M2FNUZ() = 0; +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + virtual MLDataType DataTypeImpl__GetType_Float4E2M1x2() = 0; #endif virtual MLDataType DataTypeImpl__GetType_Int4x2() = 0; virtual MLDataType DataTypeImpl__GetType_UInt4x2() = 0; @@ -784,6 +787,10 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetTensorType_Float8E5M2() = 0; virtual MLDataType DataTypeImpl__GetTensorType_Float8E5M2FNUZ() = 0; #endif +#if !defined(DISABLE_FLOAT4_TYPES) + virtual MLDataType DataTypeImpl__GetTensorType_Float4E2M1x2() = 0; +#endif + virtual MLDataType DataTypeImpl__GetTensorType_Int4x2() = 0; virtual MLDataType DataTypeImpl__GetTensorType_UInt4x2() = 0; @@ -826,6 +833,8 @@ struct ProviderHost { virtual const std::vector& DataTypeImpl__AllTensorTypes() = 0; virtual const std::vector& DataTypeImpl__AllTensorTypesIRv4() = 0; virtual const std::vector& DataTypeImpl__AllTensorTypesIRv9() = 0; + virtual const std::vector& DataTypeImpl__AllTensorTypesIRv10() = 0; + virtual const std::vector& DataTypeImpl__AllTensorTypesIRv11() = 0; virtual const std::vector& DataTypeImpl__AllIEEEFloatTensorTypes() = 0; @@ -1224,6 +1233,9 @@ struct ProviderHost { virtual Float8E4M3FNUZ* Tensor__MutableData_Float8E4M3FNUZ(Tensor* p) = 0; virtual Float8E5M2* Tensor__MutableData_Float8E5M2(Tensor* p) = 0; virtual Float8E5M2FNUZ* Tensor__MutableData_Float8E5M2FNUZ(Tensor* p) = 0; +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + virtual Float4E2M1x2* Tensor__MutableData_Float4E2M1x2(Tensor* p) = 0; #endif virtual Int4x2* Tensor__MutableData_Int4x2(Tensor* p) = 0; virtual UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) = 0; @@ -1247,6 +1259,9 @@ struct ProviderHost { virtual const Float8E4M3FNUZ* Tensor__Data_Float8E4M3FNUZ(const Tensor* p) = 0; virtual const Float8E5M2* Tensor__Data_Float8E5M2(const Tensor* p) = 0; virtual const Float8E5M2FNUZ* Tensor__Data_Float8E5M2FNUZ(const Tensor* p) = 0; +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + virtual const Float4E2M1x2* Tensor__Data_Float4E2M1x2(const Tensor* p) = 0; #endif virtual const Int4x2* Tensor__Data_Int4x2(const Tensor* p) = 0; virtual const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) = 0; @@ -1280,6 +1295,9 @@ struct ProviderHost { virtual bool Tensor__IsDataType_Float8E4M3FNUZ(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_Float8E5M2(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_Float8E5M2FNUZ(const Tensor* p) noexcept = 0; +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + virtual bool Tensor__IsDataType_Float4E2M1x2(const Tensor* p) noexcept = 0; #endif virtual bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h 
b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 19b4636c3766d..1ab32e649ed40 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -776,6 +776,8 @@ class DataTypeImpl final { static const std::vector& AllTensorTypes() { return g_host->DataTypeImpl__AllTensorTypes(); } static const std::vector& AllTensorTypesIRv4() { return g_host->DataTypeImpl__AllTensorTypesIRv4(); } static const std::vector& AllTensorTypesIRv9() { return g_host->DataTypeImpl__AllTensorTypesIRv9(); } + static const std::vector& AllTensorTypesIRv10() { return g_host->DataTypeImpl__AllTensorTypesIRv10(); } + static const std::vector& AllTensorTypesIRv11() { return g_host->DataTypeImpl__AllTensorTypesIRv11(); } static const std::vector& AllIEEEFloatTensorTypes() { return g_host->DataTypeImpl__AllIEEEFloatTensorTypes(); } @@ -1513,6 +1515,11 @@ template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_Float8E5M2FNUZ(this); } #endif +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_Float4E2M1x2(this); } +#endif + template <> inline bool* Tensor::MutableData() { return g_host->Tensor__MutableData_bool(this); } template <> @@ -1555,6 +1562,11 @@ template <> inline Float8E5M2FNUZ* Tensor::MutableData() { return g_host->Tensor__MutableData_Float8E5M2FNUZ(this); } #endif +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +inline Float4E2M1x2* Tensor::MutableData() { return g_host->Tensor__MutableData_Float4E2M1x2(this); } +#endif + template <> inline const bool* Tensor::Data() const { return g_host->Tensor__Data_bool(this); } template <> @@ -1597,6 +1609,11 @@ template <> inline const Float8E5M2FNUZ* Tensor::Data() const { return g_host->Tensor__Data_Float8E5M2FNUZ(this); } #endif +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +inline const Float4E2M1x2* Tensor::Data() const { return g_host->Tensor__Data_Float4E2M1x2(this); } +#endif + // SparseTensor #if !defined(DISABLE_SPARSE_TENSORS) struct SparseTensor final { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 41cf8be1d1412..7e28a15535dbb 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -948,6 +948,11 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetType_Float8E5M2() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_Float8E5M2FNUZ() override { return DataTypeImpl::GetType(); } #endif + +#if !defined(DISABLE_FLOAT4_TYPES) + MLDataType DataTypeImpl__GetType_Float4E2M1x2() override { return DataTypeImpl::GetType(); } +#endif + MLDataType DataTypeImpl__GetType_Int4x2() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_UInt4x2() override { return DataTypeImpl::GetType(); } @@ -972,6 +977,11 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetTensorType_Float8E5M2() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_Float8E5M2FNUZ() override { return DataTypeImpl::GetTensorType(); } #endif + +#if !defined(DISABLE_FLOAT4_TYPES) + MLDataType DataTypeImpl__GetTensorType_Float4E2M1x2() override { return DataTypeImpl::GetTensorType(); } +#endif + MLDataType DataTypeImpl__GetTensorType_Int4x2() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_UInt4x2() 
override { return DataTypeImpl::GetTensorType(); } @@ -990,12 +1000,14 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetSparseTensorType_string() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_BFloat16() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_MLFloat16() override { return DataTypeImpl::GetSparseTensorType(); } + #if !defined(DISABLE_FLOAT8_TYPES) MLDataType DataTypeImpl__GetSparseTensorType_Float8E4M3FN() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_Float8E4M3FNUZ() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2() override { return DataTypeImpl::GetSparseTensorType(); } MLDataType DataTypeImpl__GetSparseTensorType_Float8E5M2FNUZ() override { return DataTypeImpl::GetSparseTensorType(); } #endif + #endif const char* DataTypeImpl__ToString(MLDataType type) override { return DataTypeImpl::ToString(type); } @@ -1014,6 +1026,8 @@ struct ProviderHostImpl : ProviderHost { const std::vector& DataTypeImpl__AllTensorTypes() override { return DataTypeImpl::AllTensorTypes(); } const std::vector& DataTypeImpl__AllTensorTypesIRv4() override { return DataTypeImpl::AllTensorTypesIRv4(); } const std::vector& DataTypeImpl__AllTensorTypesIRv9() override { return DataTypeImpl::AllTensorTypesIRv9(); } + const std::vector& DataTypeImpl__AllTensorTypesIRv10() override { return DataTypeImpl::AllTensorTypesIRv10(); } + const std::vector& DataTypeImpl__AllTensorTypesIRv11() override { return DataTypeImpl::AllTensorTypesIRv11(); } const std::vector& DataTypeImpl__AllIEEEFloatTensorTypes() override { return DataTypeImpl::AllIEEEFloatTensorTypes(); } @@ -1588,6 +1602,11 @@ struct ProviderHostImpl : ProviderHost { Float8E5M2* Tensor__MutableData_Float8E5M2(Tensor* p) override { return p->MutableData(); } Float8E5M2FNUZ* Tensor__MutableData_Float8E5M2FNUZ(Tensor* p) override { return p->MutableData(); } #endif + +#if !defined(DISABLE_FLOAT4_TYPES) + Float4E2M1x2* Tensor__MutableData_Float4E2M1x2(Tensor* p) override { return p->MutableData(); } +#endif + Int4x2* Tensor__MutableData_Int4x2(Tensor* p) override { return p->MutableData(); } UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) override { return p->MutableData(); } @@ -1611,6 +1630,11 @@ struct ProviderHostImpl : ProviderHost { const Float8E5M2* Tensor__Data_Float8E5M2(const Tensor* p) override { return p->Data(); } const Float8E5M2FNUZ* Tensor__Data_Float8E5M2FNUZ(const Tensor* p) override { return p->Data(); } #endif + +#if !defined(DISABLE_FLOAT4_TYPES) + const Float4E2M1x2* Tensor__Data_Float4E2M1x2(const Tensor* p) override { return p->Data(); } +#endif + const Int4x2* Tensor__Data_Int4x2(const Tensor* p) override { return p->Data(); } const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) override { return p->Data(); } @@ -1642,6 +1666,11 @@ struct ProviderHostImpl : ProviderHost { bool Tensor__IsDataType_Float8E5M2(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_Float8E5M2FNUZ(const Tensor* p) noexcept override { return p->IsDataType(); } #endif + +#if !defined(DISABLE_FLOAT4_TYPES) + bool Tensor__IsDataType_Float4E2M1x2(const Tensor* p) noexcept override { return p->IsDataType(); } +#endif + bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept override { return 
p->IsDataType(); } diff --git a/onnxruntime/test/contrib_ops/gemm_float8_test.cc b/onnxruntime/test/contrib_ops/gemm_float8_test.cc index c022736075cde..36b2612b0d0d6 100644 --- a/onnxruntime/test/contrib_ops/gemm_float8_test.cc +++ b/onnxruntime/test/contrib_ops/gemm_float8_test.cc @@ -91,7 +91,7 @@ template void TestGemmFloat8WithFloat8(int64_t dtype) { int min_cuda_architecture = 11080; if (!HasCudaEnvironment(min_cuda_architecture)) { - LOGS_DEFAULT(WARNING) << "Hardware NOT support Matrix Multiplication for FLOAT8"; + LOGS_DEFAULT(WARNING) << "Hardware does NOT support Matrix Multiplication for FLOAT8"; return; } OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); @@ -101,10 +101,10 @@ void TestGemmFloat8WithFloat8(int64_t dtype) { test.AddAttribute("beta", 1.0f); test.AddAttribute("activation", "NONE"); test.AddAttribute("dtype", dtype); - test.AddInput("A", {2, 4}, _TypeCvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}))); - test.AddInput("B", {3, 4}, _TypeCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); - test.AddInput("C", {2, 3}, _TypeCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); - test.AddOutput("Y", {2, 3}, _TypeCvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}))); + test.AddInput("A", {2, 4}, _TypedCvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}))); + test.AddInput("B", {3, 4}, _TypedCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddInput("C", {2, 3}, _TypedCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddOutput("Y", {2, 3}, _TypedCvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}))); std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); diff --git a/onnxruntime/test/framework/float_4_test.cc b/onnxruntime/test/framework/float_4_test.cc new file mode 100644 index 0000000000000..03d13d99c7bc1 --- /dev/null +++ b/onnxruntime/test/framework/float_4_test.cc @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
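The expectations in the new float_4_test.cc below follow from the small FP4 E2M1 value set. As a quick reference (illustrative only, not part of the patch): the representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6; there are no Inf/NaN encodings; and the conversion exercised below appears to round to nearest with ties to the even-mantissa value, saturating out-of-range inputs, +/-Inf and NaN to the maximum magnitude 6. A naive quantizer written under those assumptions reproduces the test expectations:

#include <cmath>

constexpr float kFp4E2M1Values[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

inline float QuantizeToFp4E2M1(float x) {
  if (std::isnan(x)) return 6.0f;   // no NaN encoding; the tests below expect 6
  float mag = std::fabs(x);
  if (mag >= 6.0f) mag = 6.0f;      // saturate; also covers +/-Inf
  float best = kFp4E2M1Values[0];
  float best_diff = mag;
  for (int i = 1; i < 8; ++i) {
    const float diff = std::fabs(mag - kFp4E2M1Values[i]);
    // On a tie, prefer the even-indexed value (mantissa bit 0), i.e. ties-to-even.
    if (diff < best_diff || (diff == best_diff && (i % 2 == 0))) {
      best = kFp4E2M1Values[i];
      best_diff = diff;
    }
  }
  return std::signbit(x) ? -best : best;
}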
+
+#if !defined(DISABLE_FLOAT4_TYPES)
+
+#include
+#include
+
+#ifdef USE_CUDA
+// Needed for CUDA_VERSION check in float4.h
+#include <cuda.h>
+#endif
+
+#include "core/framework/float4.h"
+#include "test/test_environment.h"
+#include "test_utils.h"
+#include "gtest/gtest.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+namespace test {
+
+TEST(Float4_Tests, BasicFloatConversion) {
+  std::map<std::pair<float, float>, std::pair<float, float>> cases{
+      {std::pair<float, float>(0.f, -0.f),
+       std::pair<float, float>(0.f, -0.f)},
+      {std::pair<float, float>(0.5f, -0.5f),
+       std::pair<float, float>(0.5f, -0.5f)},
+      {std::pair<float, float>(1.5323f, -1.932f),
+       std::pair<float, float>(1.5f, -2.f)},
+      {std::pair<float, float>(2.173f, 3.5f),
+       std::pair<float, float>(2.f, 4.f)},
+      {std::pair<float, float>(6.5f, -9.f),
+       std::pair<float, float>(6.f, -6.f)},
+      {std::pair<float, float>(-2.5f, 2.5f),
+       std::pair<float, float>(-2.f, 2.f)},
+      {std::pair<float, float>(-1.25f, 1.25f),
+       std::pair<float, float>(-1.f, 1.f)},
+      {std::pair<float, float>(std::numeric_limits<float>::infinity(),
+                               -std::numeric_limits<float>::infinity()),
+       std::pair<float, float>(6.f, -6.f)}};
+
+  for (auto& c : cases) {
+    auto f4e2m1_instance = Float4E2M1x2(c.first.first, c.first.second);
+
+    // Using ToFloat2() interface
+    auto f_cvt_returned = f4e2m1_instance.ToFloat2();
+    EXPECT_EQ(f_cvt_returned.first, c.second.first);
+    EXPECT_EQ(f_cvt_returned.second, c.second.second);
+
+    // Using GetElem() interface
+    EXPECT_EQ(f4e2m1_instance.GetElem(0), c.second.first);
+    EXPECT_EQ(f4e2m1_instance.GetElem(1), c.second.second);
+  }
+
+  // NaNs
+  auto NaNs_converted = Float4E2M1x2(std::numeric_limits<float>::quiet_NaN(),
+                                     -std::numeric_limits<float>::quiet_NaN())
+                            .ToFloat2();
+
+  EXPECT_EQ(NaNs_converted.first, 6.f);
+  EXPECT_EQ(NaNs_converted.second, 6.f);
+}
+
+TEST(Float4_Tests, BitRepresentationChecks) {
+  // FromBits test
+  std::pair<float, float> pair;
+  pair = Float4E2M1x2(0x87, Float4E2M1x2::FromBits()).ToFloat2();
+  EXPECT_EQ(pair.first, 6.f);
+  EXPECT_EQ(pair.second, -0.f);
+
+  pair = Float4E2M1x2(0x7F, Float4E2M1x2::FromBits()).ToFloat2();
+  EXPECT_EQ(pair.first, -6.f);
+  EXPECT_EQ(pair.second, 6.f);
+
+  // Bit representation test
+  uint8_t bits = Float4E2M1x2(-6.f, 6.f).ToBits();
+  // The high nibble stores the second value (+6) and the low nibble stores the first value (-6)
+  EXPECT_EQ((bits & 0xF0) >> 4, 0x07);  // +6
+  EXPECT_EQ((bits & 0x0F), 0x0F);       // -6
+}
+
+TEST(Float4_Tests, PackingAndUnpacking) {
+  {
+    // Unpack 5 FP4 (odd count) elements
+    std::vector<Float4E2M1x2> packed{Float4E2M1x2(1.f, -0.5f),
+                                     Float4E2M1x2(4.f, -6.f),
+                                     Float4E2M1x2(3.f, 0.f)};  // padding 0
+    std::vector<float> unpacked(5, -1.f);
+
+    Float4E2M1x2::UnpackFloat4E2M1ToFloat(packed.data(), unpacked.data(), 5);
+    EXPECT_EQ(unpacked[0], packed[0].ToFloat2().first);
+    EXPECT_EQ(unpacked[1], packed[0].ToFloat2().second);
+    EXPECT_EQ(unpacked[2], packed[1].ToFloat2().first);
+    EXPECT_EQ(unpacked[3], packed[1].ToFloat2().second);
+    EXPECT_EQ(unpacked[4], packed[2].ToFloat2().first);
+  }
+
+  {
+    // Unpack 6 FP4 (even count) elements
+    std::vector<Float4E2M1x2> packed{Float4E2M1x2(1.f, -0.5f),
+                                     Float4E2M1x2(4.f, -6.f),
+                                     Float4E2M1x2(3.f, -3.f)};
+    std::vector<float> unpacked(6, -1.f);
+
+    Float4E2M1x2::UnpackFloat4E2M1ToFloat(packed.data(), unpacked.data(), 6);
+    EXPECT_EQ(unpacked[0], packed[0].ToFloat2().first);
+    EXPECT_EQ(unpacked[1], packed[0].ToFloat2().second);
+    EXPECT_EQ(unpacked[2], packed[1].ToFloat2().first);
+    EXPECT_EQ(unpacked[3], packed[1].ToFloat2().second);
+    EXPECT_EQ(unpacked[4], packed[2].ToFloat2().first);
+    EXPECT_EQ(unpacked[5], packed[2].ToFloat2().second);
+  }
+
+  {
+    // Pack 5 float (odd count) elements
+    std::vector<float> unpacked{1.f, -0.5f, 4.f, -6.f, 3.f, 0.f};
+    std::vector<Float4E2M1x2> packed(3);
+
+
Float4E2M1x2::PackFloatToFloat4E2M1(unpacked.data(), packed.data(), 5); + EXPECT_EQ(Float4E2M1x2(unpacked[0], unpacked[1]), packed[0]); + EXPECT_EQ(Float4E2M1x2(unpacked[2], unpacked[3]), packed[1]); + EXPECT_EQ(Float4E2M1x2(unpacked[4], 0), packed[2]); // padding 0 + } + + { + // Pack 6 float (even count) elements + std::vector unpacked{1.f, -0.5f, 4.f, -6.f, 3.f, 8.f}; + std::vector packed(3); + + Float4E2M1x2::PackFloatToFloat4E2M1(unpacked.data(), packed.data(), 6); + EXPECT_EQ(Float4E2M1x2(unpacked[0], unpacked[1]), packed[0]); + EXPECT_EQ(Float4E2M1x2(unpacked[2], unpacked[3]), packed[1]); + EXPECT_EQ(Float4E2M1x2(unpacked[4], unpacked[5]), packed[2]); + } +} + +TEST(Float4_Tests, TestLimits) { + EXPECT_FALSE(std::numeric_limits::has_infinity); + EXPECT_FALSE(std::numeric_limits::has_quiet_NaN); + EXPECT_FALSE(std::numeric_limits::has_signaling_NaN); + + EXPECT_EQ(std::numeric_limits::min(), + Float4E2M1x2(0x22, onnxruntime::Float4E2M1x2::FromBits())); + EXPECT_EQ(std::numeric_limits::max(), + Float4E2M1x2(0x77, onnxruntime::Float4E2M1x2::FromBits())); + EXPECT_EQ(std::numeric_limits::lowest(), + Float4E2M1x2(0xFF, onnxruntime::Float4E2M1x2::FromBits())); + EXPECT_EQ(std::numeric_limits::denorm_min(), + Float4E2M1x2(0x11, onnxruntime::Float4E2M1x2::FromBits())); +} + +} // namespace test +} // namespace onnxruntime + +#endif // DISABLE_FLOAT4_TYPES diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index 0b4ec1bab192a..1fca7d383caf6 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -122,6 +122,29 @@ void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); } +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ Float4E2M1x2* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_pairs = (expected_num_elements + 1) / 2; + + if (num_packed_pairs != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed int4 pairs", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_pairs); + gsl::span dst_span = gsl::make_span(p_data, num_packed_pairs); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); +} +#endif + // This macro doesn't work for Float16/bool/string tensors #define DEFINE_UNPACK_TENSOR(T, Type, field_name, field_size) \ template <> \ @@ -350,6 +373,45 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) DEFINE_UNPACK_TENSOR_INT4(Int4x2, TensorProto_DataType_INT4) DEFINE_UNPACK_TENSOR_INT4(UInt4x2, TensorProto_DataType_UINT4) +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, + /*out*/ Float4E2M1x2* p_data, size_t expected_num_elems) { + if (nullptr == p_data) { + const size_t size = raw_data != nullptr ? 
raw_data_len : tensor.int32_data_size(); + if (size == 0) { + return; + } + ORT_CXX_API_THROW("p_data == nullptr, but size != 0", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + if (ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1 != tensor.data_type()) { + ORT_CXX_API_THROW("TensorProto data type is not FLOAT4", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + + size_t expected_float4_pairs = (expected_num_elems + 1) / 2; + + if (raw_data != nullptr) { + UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); + return; + } + + if (static_cast(tensor.int32_data_size()) != expected_float4_pairs) { + ORT_CXX_API_THROW("UnpackTensor: the pre-allocated size does not match the size in proto", + OrtErrorCode::ORT_FAIL); + } + + constexpr int max_value = std::numeric_limits::max(); + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { + int v = tensor.int32_data()[i]; + if (v < 0 || v > max_value) { + ORT_CXX_API_THROW( + "data overflow", OrtErrorCode::ORT_FAIL); + } + p_data[i] = Float4E2M1x2(static_cast(v), Float4E2M1x2::FromBits()); + } +} +#endif + #define CASE_PROTO_TRACE(X, Y) \ case onnx::TensorProto_DataType::TensorProto_DataType_##X: \ if (!CalcMemSizeForArrayWithAlignment(size, sizeof(Y), alignment, out)) { \ @@ -364,6 +426,15 @@ DEFINE_UNPACK_TENSOR_INT4(UInt4x2, TensorProto_DataType_UINT4) } \ break; +#if !defined(DISABLE_FLOAT4_TYPES) +#define CASE_PROTO_TRACE_FLOAT4(X) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!CalcMemSizeForArrayWithAlignment((size + 1) / 2, 1, alignment, out)) { \ + ORT_CXX_API_THROW("Invalid TensorProto", OrtErrorCode::ORT_FAIL); \ + } \ + break; +#endif + template Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) { const auto& dims = tensor_proto.dims(); @@ -396,6 +467,10 @@ Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_p CASE_PROTO_TRACE(FLOAT8E5M2, Float8E5M2); CASE_PROTO_TRACE(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); #endif +#if !defined(DISABLE_FLOAT4_TYPES) + CASE_PROTO_TRACE_FLOAT4(FLOAT4E2M1); +#endif + CASE_PROTO_TRACE(STRING, std::string); CASE_PROTO_TRACE_INT4(UINT4); CASE_PROTO_TRACE_INT4(INT4); @@ -483,6 +558,7 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { CASE_TYPE(FLOAT8E4M3FNUZ) CASE_TYPE(FLOAT8E5M2) CASE_TYPE(FLOAT8E5M2FNUZ) + CASE_TYPE(FLOAT4E2M1) CASE_TYPE(UINT4) CASE_TYPE(INT4) default: @@ -548,6 +624,9 @@ Status TensorProtoToMLValue(const onnx::TensorProto& tensor_proto, const MemBuff CASE_PROTO(FLOAT8E4M3FNUZ, Float8E4M3FNUZ); CASE_PROTO(FLOAT8E5M2, Float8E5M2); CASE_PROTO(FLOAT8E5M2FNUZ, Float8E5M2FNUZ); +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + CASE_PROTO(FLOAT4E2M1, Float4E2M1x2); #endif CASE_PROTO(INT4, Int4x2); CASE_PROTO(UINT4, UInt4x2); @@ -673,6 +752,12 @@ Status MLValueToTensorProto(Ort::Value& value, onnx::TensorProto& tensor_proto) tensor_proto_dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ; tensor_elem_bytes = 1; break; +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1: + tensor_proto_dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT4E2M1; + tensor_elem_bytes = 1; + break; #endif case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: tensor_proto_dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4; diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index 182ee4a9550fe..b55a43c92637c 100644 --- 
a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -10,6 +10,7 @@ #include #include "core/framework/customregistry.h" +#include "core/framework/float4.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/run_options.h" #include "core/framework/tensor.h" @@ -698,7 +699,15 @@ class BaseTester { const int64_t expected_values_count = T::CalcNumInt4Pairs(shape.Size()); ORT_ENFORCE(expected_values_count == values_count, values_count, " input values doesn't match tensor size of ", expected_values_count); - } else { + } +#if !defined(DISABLE_FLOAT4_TYPES) + else if constexpr (std::is_same_v) { + const int64_t expected_values_count = T::CalcNumFloat4Pairs(shape.Size()); + ORT_ENFORCE(expected_values_count == values_count, values_count, + " input values doesn't match tensor size of ", expected_values_count); + } +#endif + else { ORT_ENFORCE(shape.Size() == values_count, values_count, " input values doesn't match tensor size of ", shape.Size()); } diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index ff5895623fc9b..19ac5e06cddd8 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -8,6 +8,7 @@ #include "core/graph/constants.h" #include "core/framework/TensorSeq.h" #include "core/framework/int4.h" +#include "core/framework/float4.h" #include "test/framework/test_utils.h" #include "test/providers/provider_test_utils.h" @@ -179,12 +180,52 @@ struct TensorCheck { } }; +#if !defined(DISABLE_FLOAT4_TYPES) +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& provider_type) const { + // TODO(hasesh): Implement sorting and other requests the `params` may be holding + ORT_UNUSED_PARAMETER(params); + + if (actual.Shape() != expected.Shape()) { + ORT_THROW("Shape mismatch"); + } + + const auto size = actual.Shape().Size(); + + const Float4E2M1x2* expected_data = expected.Data(); + const Float4E2M1x2* actual_data = actual.Data(); + + // TODO(hasesh): Add separate per-EP tolerance for Float4 types + // For now, using float tolerance is fine + auto tolerance_params = get_tolerance_params(params, provider_type); + + for (int64_t i = 0; i < size; ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + + float expected_f = expected_data[r].GetElem(c); + float actual_f = actual_data[r].GetElem(c); + + if (std::isnan(expected_f)) { + EXPECT_TRUE(std::isnan(actual_f)) << "Expected NaN. i:" << i; + } else if (std::isinf(expected_f)) { // Test infinity for equality + EXPECT_EQ(expected_f, actual_f) << "Expected infinity. 
i:" << i; + } else { + float tolerance = get_tolerance(tolerance_params, expected_f); + EXPECT_NEAR(expected_f, actual_f, tolerance) << "i:" << i; + } + } + } +}; +#endif + template <> struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, const std::string& /*provider_type*/) const { ORT_UNUSED_PARAMETER(params); - Tensor expected_sorted, actual_sorted; const Int4x2* cur_expected; const Int4x2* cur_actual; const auto size = actual.Shape().Size(); @@ -204,7 +245,6 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, const std::string& /*provider_type*/) const { ORT_UNUSED_PARAMETER(params); - Tensor expected_sorted, actual_sorted; const UInt4x2* cur_expected; const UInt4x2* cur_actual; const auto size = actual.Shape().Size(); @@ -498,8 +538,10 @@ void Check(std::string_view name, const OrtValue& expected, const Tensor int8_t, int16_t, int32_t, int64_t, std::string, Int4x2, UInt4x2, #if !defined(DISABLE_FLOAT8_TYPES) - Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ, +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + Float4E2M1x2, #endif MLFloat16, BFloat16> t_disp(actual.GetElementType()); diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 68d4f3559504a..cab2a93be9022 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -58,7 +58,8 @@ void TestCastOp(gsl::span input, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& expected_failure_string = "", int opset = 21, - Saturate saturate = Saturate::None) { + Saturate saturate = Saturate::None, + bool cuda_only = false) { OpTester test("Cast", opset); test.AddAttribute("to", utils::ToTensorProtoElementType()); test.AddInput("input", dimensions, input.data(), input.size()); @@ -74,6 +75,17 @@ void TestCastOp(gsl::span input, excluded_provider_types.insert(kCudaExecutionProvider); } + if (cuda_only && (excluded_provider_types.count(kCudaExecutionProvider) > 0)) { + return; + } + + std::vector> execution_providers; + if (cuda_only) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(expect_result, expected_failure_string, {}, nullptr, &execution_providers); + return; + } + test.Run(expect_result, expected_failure_string, excluded_provider_types); } @@ -1459,5 +1471,109 @@ TEST(CastOpTest, Float8E4M3FNToUInt4x2) { #endif +#if !defined(DISABLE_FLOAT4_TYPES) && defined(USE_CUDA) + +template +void CastOpTestFloatFloat4(std::vector shape, + std::vector float_data, + bool is_fp4_input = false) { + int num_pairs = static_cast(float_data.size()) / 2; + int num_fp4_elements = static_cast((float_data.size() + 1) / 2); + bool is_odd_count = (float_data.size() % 2 != 0); + + std::vector fp4_data; + fp4_data.reserve(num_fp4_elements); + + for (size_t i = 0; i < num_pairs; ++i) { + fp4_data.emplace_back(F4(float_data[i * 2], float_data[i * 2 + 1])); + } + + if (is_odd_count) { + fp4_data.emplace_back(F4(float_data[num_pairs * 2], 0)); // Padding zero + } + + if (!is_fp4_input) { + TestCastOp(gsl::make_span(float_data), gsl::make_span(fp4_data), shape, + OpTester::ExpectResult::kExpectSuccess, "", 23, Saturate::None, true); + + } else { + std::vector casted_back_float; + for (size_t i = 0; i < num_pairs; ++i) { + auto pair = fp4_data[i].ToFloat2(); + casted_back_float.push_back(pair.first); + 
@@ -1459,5 +1471,109 @@ TEST(CastOpTest, Float8E4M3FNToUInt4x2) {
 #endif

+#if !defined(DISABLE_FLOAT4_TYPES) && defined(USE_CUDA)
+
+template <typename F4>
+void CastOpTestFloatFloat4(std::vector<int64_t> shape,
+                           std::vector<float> float_data,
+                           bool is_fp4_input = false) {
+  int num_pairs = static_cast<int>(float_data.size()) / 2;
+  int num_fp4_elements = static_cast<int>((float_data.size() + 1) / 2);
+  bool is_odd_count = (float_data.size() % 2 != 0);
+
+  std::vector<F4> fp4_data;
+  fp4_data.reserve(num_fp4_elements);
+
+  for (size_t i = 0; i < num_pairs; ++i) {
+    fp4_data.emplace_back(F4(float_data[i * 2], float_data[i * 2 + 1]));
+  }
+
+  if (is_odd_count) {
+    fp4_data.emplace_back(F4(float_data[num_pairs * 2], 0));  // Padding zero
+  }
+
+  if (!is_fp4_input) {
+    TestCastOp(gsl::make_span(float_data), gsl::make_span(fp4_data), shape,
+               OpTester::ExpectResult::kExpectSuccess, "", 23, Saturate::None, true);
+
+  } else {
+    std::vector<float> casted_back_float;
+    for (size_t i = 0; i < num_pairs; ++i) {
+      auto pair = fp4_data[i].ToFloat2();
+      casted_back_float.push_back(pair.first);
+      casted_back_float.push_back(pair.second);
+    }
+
+    if (is_odd_count) {
+      casted_back_float.push_back(fp4_data[num_pairs].ToFloat2().first);
+    }
+
+    TestCastOp(gsl::make_span(fp4_data), gsl::make_span(casted_back_float), shape,
+               OpTester::ExpectResult::kExpectSuccess, "", 23, Saturate::None, true);
+  }
+}
+
+static std::vector<float> GenerateRandomFloatVector(size_t count) {
+  std::vector<float> ret;
+  ret.reserve(count);
+  for (size_t i = 0; i < count; ++i) {
+    int sign = (((rand() % 2) == 0) ? 1 : -1);
+    float random = (static_cast<float>(rand()) / static_cast<float>(RAND_MAX)) * 7.f;  // let some values be outside the range of FP4
+    ret.push_back(sign * random);
+  }
+
+  return ret;
+}
+
+TEST(CastOpTest, FloatToFloat4E2M1x2) {
+  // Even count test (with some special values)
+  CastOpTestFloatFloat4<Float4E2M1x2>({2, 2, 2},
+                                      {std::numeric_limits<float>::infinity(),
+                                       -std::numeric_limits<float>::infinity(),
+                                       7.f, -7.f,
+                                       0.5f, -0.5f,
+                                       std::numeric_limits<float>::quiet_NaN(),
+                                       -std::numeric_limits<float>::quiet_NaN()});
+
+  // Odd count test
+  CastOpTestFloatFloat4<Float4E2M1x2>({1, 3, 1},
+                                      {0.256f,
+                                       0.987f,
+                                       43.8f});
+
+  // Arbitrary sized tests
+  std::vector<int64_t> counts = {1, 5, 256, 512, 1024, 1025, 2048, 2049, 127, 89, 53, 42};
+
+  for (auto s : counts) {
+    CastOpTestFloatFloat4<Float4E2M1x2>({s, 1, 1}, GenerateRandomFloatVector(s));
+  }
+}
+
+TEST(CastOpTest, Float4E2M1x2ToFloat) {
+  // Even count test (with some special values)
+  CastOpTestFloatFloat4<Float4E2M1x2>({2, 2, 2},
+                                      {0.5f, 7.34f,
+                                       1.f, 1.5f,
+                                       2.f, 3.f,
+                                       4.f, 6.f},
+                                      true);
+
+  // Odd count test
+  CastOpTestFloatFloat4<Float4E2M1x2>({1, 3, 1},
+                                      {0.256f,
+                                       0.987f,
+                                       43.8f},
+                                      true);
+
+  // Arbitrary sized tests
+  std::vector<int64_t> counts = {1, 5, 256, 512, 1024, 1025, 2048, 2049, 127, 89, 53, 42};
+
+  for (auto s : counts) {
+    CastOpTestFloatFloat4<Float4E2M1x2>({s, 1, 1}, GenerateRandomFloatVector(s), true);
+  }
+}
+
+#endif
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index d22c8587a82b5..22da226bbb7d9 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -378,6 +378,8 @@ def generate_build_tree(
     types_to_disable = args.disable_types
     # enable/disable float 8 types
     disable_float8_types = args.android or ("float8" in types_to_disable)
+    # enable/disable float 4 type
+    disable_float4_types = args.android or args.use_rocm or ("float4" in types_to_disable)
     disable_optional_type = "optional" in types_to_disable
     disable_sparse_tensors = "sparsetensor" in types_to_disable
     if is_windows():
@@ -516,6 +518,7 @@
         "-Donnxruntime_USE_WEBNN=" + ("ON" if args.use_webnn else "OFF"),
         "-Donnxruntime_USE_CANN=" + ("ON" if args.use_cann else "OFF"),
         "-Donnxruntime_DISABLE_FLOAT8_TYPES=" + ("ON" if disable_float8_types else "OFF"),
+        "-Donnxruntime_DISABLE_FLOAT4_TYPES=" + ("ON" if disable_float4_types else "OFF"),
        "-Donnxruntime_DISABLE_SPARSE_TENSORS=" + ("ON" if disable_sparse_tensors else "OFF"),
        "-Donnxruntime_DISABLE_OPTIONAL_TYPE=" + ("ON" if disable_optional_type else "OFF"),
        "-Donnxruntime_CUDA_MINIMAL=" + ("ON" if args.enable_cuda_minimal_build else "OFF"),
diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py
index c42f8e3219da4..3736eb56bae2e 100644
--- a/tools/ci_build/build_args.py
+++ b/tools/ci_build/build_args.py
@@ -517,7 +517,7 @@ def add_size_reduction_args(parser: argparse.ArgumentParser) -> None:
         "--disable_types",
         nargs="+",
         default=[],
-        choices=["float8", "optional", "sparsetensor"],
+        choices=["float4", "float8", "optional", "sparsetensor"],
         help="Disable selected data types.",
     )
     parser.add_argument(
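For reference when reading the expected values in the FP4 cast tests above: an E2M1 element has one sign bit, two exponent bits and one mantissa bit, so its finite magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6} and there are no Inf/NaN encodings. The sketch below is an illustrative nearest-value quantizer with saturation (for example, 7.f and 43.8f would land on 6.f); it is not the production conversion, which lives in the `Float4E2M1x2` helpers and the CUDA cast kernels, and tie-breaking and NaN handling are deliberately left unspecified here:

```cpp
#include <algorithm>
#include <cmath>

// Illustrative only: snap a float to the nearest finite E2M1 magnitude, keeping the sign.
// NaN handling is intentionally omitted in this sketch.
inline float QuantizeToE2M1(float x) {
  static constexpr float kGrid[] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
  const float mag = std::min(std::fabs(x), 6.f);  // saturate: E2M1 has no Inf encoding
  float best = kGrid[0];
  for (float g : kGrid) {
    if (std::fabs(g - mag) < std::fabs(best - mag)) {
      best = g;
    }
  }
  return std::copysign(best, x);
}
```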
From 077dd2a49c6bab45ea434d0e37d5cf35ed7151eb Mon Sep 17 00:00:00 2001
From: Hariharan Seshadri
Date: Wed, 3 Sep 2025 18:20:38 -0700
Subject: [PATCH 2/3] [Core] Fix debug node input output compilation after Fp4 support was enabled in ORT (#25940)

### Description
Fix the debug node input/output dump utilities so that they compile after FP4 (`Float4E2M1x2`) support was introduced in ORT: the INT4-specific printing macros are generalized to cover all packed 4-bit types.

### Motivation and Context
Follow-up fixes to https://github.com/microsoft/onnxruntime/pull/25767/

---
 .../framework/print_tensor_statistics_utils.h |  53 +++---
 .../core/framework/print_tensor_utils.h       | 174 ++++++++++--------
 2 files changed, 121 insertions(+), 106 deletions(-)

diff --git a/onnxruntime/core/framework/print_tensor_statistics_utils.h b/onnxruntime/core/framework/print_tensor_statistics_utils.h
index c2030424ef19d..64d60e048a112 100644
--- a/onnxruntime/core/framework/print_tensor_statistics_utils.h
+++ b/onnxruntime/core/framework/print_tensor_statistics_utils.h
@@ -94,33 +94,36 @@ void PrintCommonStats(const T* data, size_t count, TensorStatisticsData& tensor_
   }
 }

-#define DEF_PRINT_COMMON_STATS_INT4(INT4_TYPE) \
-  template <> \
-  inline void PrintCommonStats<INT4_TYPE>( \
-      const INT4_TYPE* data, size_t count, TensorStatisticsData&) { \
-    using UnpackedType = typename INT4_TYPE::UnpackedType; \
-    UnpackedType min = data[0].GetElem(0); \
-    UnpackedType max = min; \
-    for (size_t i = 1; i < count; i++) { \
-      auto indices = INT4_TYPE::GetTensorElemIndices(i); \
-      auto value = data[indices.first].GetElem(indices.second); \
-      if (value > max) { \
-        max = value; \
-      } \
-      if (value < min) { \
-        min = value; \
-      } \
-    } \
- \
-    std::cout << "Min="; \
-    PrintValue(min); \
- \
-    std::cout << ",Max="; \
-    PrintValue(max); \
+#define DEF_PRINT_COMMON_STATS_4BIT(FOUR_BIT_TYPE) \
+  template <> \
+  inline void PrintCommonStats<FOUR_BIT_TYPE>( \
+      const FOUR_BIT_TYPE* data, size_t count, TensorStatisticsData&) { \
+    using UnpackedType = typename FOUR_BIT_TYPE::UnpackedType; \
+    UnpackedType min = data[0].GetElem(0); \
+    UnpackedType max = min; \
+    for (size_t i = 1; i < count; i++) { \
+      auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(i); \
+      auto value = data[indices.first].GetElem(indices.second); \
+      if (value > max) { \
+        max = value; \
+      } \
+      if (value < min) { \
+        min = value; \
+      } \
+    } \
+ \
+    std::cout << "Min="; \
+    PrintValue(min); \
+ \
+    std::cout << ",Max="; \
+    PrintValue(max); \
   }

-DEF_PRINT_COMMON_STATS_INT4(Int4x2)
-DEF_PRINT_COMMON_STATS_INT4(UInt4x2)
+DEF_PRINT_COMMON_STATS_4BIT(Int4x2)
+DEF_PRINT_COMMON_STATS_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_COMMON_STATS_4BIT(Float4E2M1x2)
+#endif

 template <typename T>
 void PrintHalfStats(const T* data, size_t count) {
diff --git a/onnxruntime/core/framework/print_tensor_utils.h b/onnxruntime/core/framework/print_tensor_utils.h
index e6af5e9e5832b..47be8b8dc2057 100644
--- a/onnxruntime/core/framework/print_tensor_utils.h
+++ b/onnxruntime/core/framework/print_tensor_utils.h
@@ -74,28 +74,31 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
   std::cout << std::endl;
 }

-// INT4 - Print snippet of 2D tensor with shape (dim0, dim1)
-#define DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(INT4_TYPE) \
-  template <> \
-  inline void PrintCpuTensorSnippet<INT4_TYPE>(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, \
-                                               int64_t edge_items) { \
-    for (int64_t i = 0; i < dim0; i++) { \
-      SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
-      auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
-      PrintValue(tensor[indices.first].GetElem(indices.second)); \
-      for (int64_t j = 1; j < dim1; j++) { \
-        SKIP_NON_EDGE_ITEMS_LAST_DIM(dim1, j, edge_items); \
-        std::cout << ", "; \
-        indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
-        PrintValue(tensor[indices.first].GetElem(indices.second)); \
-      } \
-      std::cout << std::endl; \
-    } \
-    std::cout << std::endl; \
+// 4 BIT TYPE - Print snippet of 2D tensor with shape (dim0, dim1)
+#define DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(FOUR_BIT_TYPE) \
+  template <> \
+  inline void PrintCpuTensorSnippet<FOUR_BIT_TYPE>(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, \
+                                                   int64_t edge_items) { \
+    for (int64_t i = 0; i < dim0; i++) { \
+      SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
+      auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
+      PrintValue(tensor[indices.first].GetElem(indices.second)); \
+      for (int64_t j = 1; j < dim1; j++) { \
+        SKIP_NON_EDGE_ITEMS_LAST_DIM(dim1, j, edge_items); \
+        std::cout << ", "; \
+        indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
+        PrintValue(tensor[indices.first].GetElem(indices.second)); \
+      } \
+      std::cout << std::endl; \
+    } \
+    std::cout << std::endl; \
   }

-DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_SNIPPET_2D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(Float4E2M1x2)
+#endif

 // Print snippet of 3D tensor with shape (dim0, dim1, dim2)
 template <typename T>
@@ -117,32 +120,35 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t
   std::cout << std::endl;
 }

-// INT4 - Print snippet of 3D tensor with shape (dim0, dim1, dim2)
-#define DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(INT4_TYPE) \
-  template <> \
-  inline void PrintCpuTensorSnippet<INT4_TYPE>(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2, \
-                                               int64_t edge_items) { \
-    for (int64_t i = 0; i < dim0; i++) { \
-      SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
-      for (int64_t j = 0; j < dim1; j++) { \
-        SKIP_NON_EDGE_ITEMS(dim1, j, edge_items); \
-        auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
-        PrintValue(tensor[indices.first].GetElem(indices.second)); \
-        for (int64_t k = 1; k < dim2; k++) { \
-          SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items); \
-          std::cout << ", "; \
-          indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
-          PrintValue(tensor[indices.first].GetElem(indices.second)); \
-        } \
-        std::cout << std::endl; \
-      } \
-      std::cout << std::endl; \
-    } \
-    std::cout << std::endl; \
+// 4 BIT TYPE - Print snippet of 3D tensor with shape (dim0, dim1, dim2)
+#define DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(FOUR_BIT_TYPE) \
+  template <> \
+  inline void PrintCpuTensorSnippet<FOUR_BIT_TYPE>(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2, \
+                                                   int64_t edge_items) { \
+    for (int64_t i = 0; i < dim0; i++) { \
+      SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \
+      for (int64_t j = 0; j < dim1; j++) { \
+        SKIP_NON_EDGE_ITEMS(dim1, j, edge_items); \
+        auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
+        PrintValue(tensor[indices.first].GetElem(indices.second)); \
+        for (int64_t k = 1; k < dim2; k++) { \
+          SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items); \
+          std::cout << ", "; \
+          indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
+          PrintValue(tensor[indices.first].GetElem(indices.second)); \
+        } \
+        std::cout << std::endl; \
+      } \
+      std::cout << std::endl; \
+    } \
+    std::cout << std::endl; \
   }

-DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_SNIPPET_3D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(Float4E2M1x2)
+#endif

 // Print 2D tensor
 template <typename T>
@@ -158,25 +164,28 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1) {
   std::cout << std::endl;
 }

-// INT4 - Print 2D tensor
-#define DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(INT4_TYPE) \
-  template <> \
-  inline void PrintCpuTensorFull<INT4_TYPE>(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1) { \
-    for (int64_t i = 0; i < dim0; i++) { \
-      auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
-      PrintValue(tensor[indices.first].GetElem(indices.second)); \
-      for (int64_t j = 1; j < dim1; j++) { \
-        std::cout << ", "; \
-        indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
-        PrintValue(tensor[indices.first].GetElem(indices.second)); \
-      } \
-      std::cout << std::endl; \
-    } \
-    std::cout << std::endl; \
+// 4 BIT TYPE - Print 2D tensor
+#define DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(FOUR_BIT_TYPE) \
+  template <> \
+  inline void PrintCpuTensorFull<FOUR_BIT_TYPE>(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1) { \
+    for (int64_t i = 0; i < dim0; i++) { \
+      auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1)); \
+      PrintValue(tensor[indices.first].GetElem(indices.second)); \
+      for (int64_t j = 1; j < dim1; j++) { \
+        std::cout << ", "; \
+        indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 + j)); \
+        PrintValue(tensor[indices.first].GetElem(indices.second)); \
+      } \
+      std::cout << std::endl; \
+    } \
+    std::cout << std::endl; \
   }

-DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_FULL_2D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(Float4E2M1x2)
+#endif

 // Print 3D tensor
 template <typename T>
@@ -195,28 +204,31 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim
   std::cout << std::endl;
 }

-// INT4 - Print 3D tensor
-#define DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(INT4_TYPE) \
-  template <> \
-  inline void PrintCpuTensorFull<INT4_TYPE>(const INT4_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2) { \
-    for (int64_t i = 0; i < dim0; i++) { \
-      for (int64_t j = 0; j < dim1; j++) { \
-        auto indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
-        PrintValue(tensor[indices.first].GetElem(indices.second)); \
-        for (int64_t k = 1; k < dim2; k++) { \
-          std::cout << ", "; \
-          indices = INT4_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
-          PrintValue(tensor[indices.first].GetElem(indices.second)); \
-        } \
-        std::cout << std::endl; \
-      } \
-      std::cout << std::endl; \
-    } \
-    std::cout << std::endl; \
+// 4 BIT TYPE - Print 3D tensor
+#define DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(FOUR_BIT_TYPE) \
+  template <> \
+  inline void PrintCpuTensorFull<FOUR_BIT_TYPE>(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2) { \
+    for (int64_t i = 0; i < dim0; i++) { \
+      for (int64_t j = 0; j < dim1; j++) { \
+        auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2)); \
+        PrintValue(tensor[indices.first].GetElem(indices.second)); \
+        for (int64_t k = 1; k < dim2; k++) { \
+          std::cout << ", "; \
+          indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast<size_t>(i * dim1 * dim2 + j * dim2 + k)); \
+          PrintValue(tensor[indices.first].GetElem(indices.second)); \
+        } \
+        std::cout << std::endl; \
+      } \
+      std::cout << std::endl; \
+    } \
+    std::cout << std::endl; \
   }

-DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(Int4x2)
-DEF_PRINT_CPU_TENSOR_FULL_3D_INT4(UInt4x2)
+DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(Int4x2)
+DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(UInt4x2)
+#if !defined(DISABLE_FLOAT4_TYPES)
+DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(Float4E2M1x2)
+#endif

 template <typename T>
 void PrintCpuTensor(const onnxruntime::Tensor& tensor,
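The printing macros above address packed 4-bit data the same way the `Float4E2M1x2` checker in `checkers.cc` does: logical element `i` lives in packed slot `i / 2`, nibble `i % 2`, which is the `(first, second)` pair these helpers consume from `GetTensorElemIndices`. A trivial sketch of that mapping:

```cpp
#include <cstddef>
#include <utility>

// Flat index -> (index into the packed array, element within the byte pair).
inline std::pair<size_t, size_t> FourBitElemIndices(size_t i) {
  return {i >> 1, i & 0x1};
}
```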
From 1a87ee2fb8d027c7c129ae9dc4815d7cbaa77a6b Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Wed, 1 Oct 2025 18:48:20 +0000
Subject: [PATCH 3/3] Link FP4 types between OnnxRT and MIGraphX APIs

Do this so that MIGraphX can take fp4 types on input/output tensors and use them
to perform inference via the MIGraphX API.

---
 .../core/providers/migraphx/migraphx_execution_provider.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
index a59347841be95..239a5054801bc 100644
--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -268,6 +268,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
     case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
     case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16:
     case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
+    case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT4E2M1:
     case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
     case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
     case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
@@ -318,6 +319,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ:
       mgx_type = migraphx_shape_fp8e5m2fnuz_type;
       break;
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1:
+      mgx_type = migraphx_shape_fp4x2_type;
+      break;
     case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4:
       mgx_type = migraphx_shape_int8_type;
       break;
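With the type plumbed through ORT and, above, the MIGraphX EP, a packed FP4 tensor can be described to the runtime through the existing C/C++ API using the new `ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1` element type. The sketch below is illustrative only: the helper name and shapes are placeholders, it assumes headers that already contain the FP4 additions from this series, and whether a given execution provider actually accepts the input depends on the build. The point is simply that the backing buffer for N logical FP4 elements uses the packed two-per-byte layout, i.e. (N + 1) / 2 bytes:

```cpp
#include <cstdint>
#include <vector>
#include "onnxruntime_cxx_api.h"

// Hypothetical helper: wrap an already-packed FP4 buffer ((N + 1) / 2 bytes) as an OrtValue.
Ort::Value MakeFp4Input(const std::vector<int64_t>& shape, std::vector<uint8_t>& packed_bytes) {
  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  return Ort::Value::CreateTensor(mem_info, packed_bytes.data(), packed_bytes.size(),
                                  shape.data(), shape.size(),
                                  ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1);
}
```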