diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0d66e69c8e925..73b38afff34f3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -242,8 +242,8 @@ Do not modify directly.* |||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[1, 8]|**T** = tensor(double), tensor(float)| |MatMulInteger|*in* A:**T1**
*in* B:**T2**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*out* Y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int32)| -|Max|*in* data_0:**T**
*out* max:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|Max|*in* data_0:**T**
*out* max:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||[8, 11]|**T** = tensor(double), tensor(float)| |||[6, 7]|**T** = tensor(float)| |MaxPool|*in* X:**T**
*out* Y:**T**

or

*in* X:**T**
*out* Y:**T**
*out* Indices:**I**|22+|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int8), tensor(uint8)| @@ -263,8 +263,8 @@ Do not modify directly.* |MelWeightMatrix|*in* num_mel_bins:**T1**
*in* dft_length:**T1**
*in* sample_rate:**T1**
*in* lower_edge_hertz:**T2**
*in* upper_edge_hertz:**T2**
*out* output:**T3**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(float)
**T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |MemcpyFromHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |MemcpyToHost|*in* X:**T**
*out* Y:**T**|1+|**T** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| +|Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| +|||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||[8, 11]|**T** = tensor(double), tensor(float)| |||[6, 7]|**T** = tensor(float)| |Mod|*in* A:**T**
*in* B:**T**
*out* C:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 72300d028501a..b940d71e1165e 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -20,14 +20,16 @@ namespace op_kernel_type_control { ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 8, Input, 0, float, double); ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0, - float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); + float, double, MLFloat16, int8_t, int32_t, uint32_t, + int64_t, uint8_t, uint64_t); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Max, 12, Input, 0, int32_t, int64_t); // Min ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 8, Input, 0, float, double); ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0, - float, double, MLFloat16, int32_t, uint32_t, int64_t, uint64_t); + float, double, MLFloat16, int8_t, int32_t, uint32_t, + int64_t, uint8_t, uint64_t); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES(kCpuExecutionProvider, kOnnxDomain, Min, 12, Input, 0, int32_t, int64_t); @@ -989,7 +991,8 @@ Status Min_8::Compute(OpKernelContext* context) const { return MinMaxMLFloat16(*this, context); break; default: - utils::MLTypeCallDispatcher + utils::MLTypeCallDispatcher t_disp(dt_type); return t_disp.InvokeRet(*this, context); } @@ -1055,7 +1058,8 @@ Status Max_8::Compute(OpKernelContext* context) const { return MinMaxMLFloat16(*this, context); break; default: - utils::MLTypeCallDispatcher + utils::MLTypeCallDispatcher t_disp(dt_type); return t_disp.InvokeRet(*this, context); } diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index d848865f719b8..cbb8ca43e8f06 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -3700,6 +3700,74 @@ TEST(MathOpTest, Equal_multidirectional_broadcastAB_bool) { test.Run(); } +TEST(MathOpTest, Max_12_Int8) { + OpTester test("Max", 12); + test.AddInput("data_0", {1, 3}, + {1, 2, 3}); + test.AddInput("data_2", {3, 3}, + {10, 20, 30, + 40, 50, 60, + 70, 80, 90}); + test.AddInput("data_1", {3, 1}, + {-1, -2, 127}); + test.AddOutput("max", {3, 3}, + {10, 20, 30, + 40, 50, 60, + 127, 127, 127}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(MathOpTest, Max_12_UInt8) { + OpTester test("Max", 12); + test.AddInput("data_0", {1, 3}, + {1, 20, 30}); + test.AddInput("data_2", {3, 3}, + {10, 20, 30, + 40, 50, 60, + 70, 80, 90}); + test.AddInput("data_1", {3, 1}, + {100, 20, 30}); + test.AddOutput("max", {3, 3}, + {100, 100, 100, + 40, 50, 60, + 70, 80, 90}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(MathOpTest, Min_12_Int8) { + OpTester test("Min", 12); + test.AddInput("data_0", {1, 3}, + {1, 2, 3}); + test.AddInput("data_2", {3, 3}, + {10, 20, 30, + 40, 50, 60, + -70, -80, -90}); + test.AddInput("data_1", {3, 1}, + {-1, 20, 127}); + test.AddOutput("min", {3, 3}, + {-1, -1, -1, + 1, 2, 3, + -70, -80, -90}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(MathOpTest, Min_12_UInt8) { + OpTester test("Min", 12); + test.AddInput("data_0", {1, 3}, + {1, 20, 30}); + test.AddInput("data_2", {3, 3}, + {10, 20, 30, + 40, 50, 60, + 70, 80, 90}); + test.AddInput("data_1", {3, 1}, + {1, 20, 30}); + test.AddOutput("min", {3, 3}, + {1, 1, 1, + 1, 20, 20, + 1, 20, 30}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(MathOpTest, Mean_6) { OpTester test("Mean", 6); std::vector dims{3, 3};