
Commit 63bb3b5

dulinriley authored and facebook-github-bot committed
Support torch.int32 as a dtype for quantize and dequantize (#289)
Summary:
Pull Request resolved: #289

Ops like `quantized_decomposed.quantize_per_tensor.default` did not support an int32 quantized type. Add support for these to the portable and ATen runtimes. This is important for Turing, which uses int32 to represent uint16 (as the latter is not a valid PyTorch dtype).

Reviewed By: kimishpatel
Differential Revision: D49202048
fbshipit-source-id: 0faa89ce1d34b60ece443fb02fa14f02abf2d376
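For context, affine quantization maps a float to an integer as q = clamp(round(x / scale) + zero_point, quant_min, quant_max); with torch.int32 as the storage dtype, a uint16-style window (quant_min = 0, quant_max = 65535) fits comfortably inside int32's range. Below is a minimal standalone C++ sketch of that formula, assuming the standard affine scheme; quantize_value is a hypothetical helper, not the ExecuTorch kernel itself.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Standard affine quantization: q = clamp(round(x / scale) + zp, qmin, qmax).
int32_t quantize_value(float x, double scale, int32_t zero_point,
                       int32_t quant_min, int32_t quant_max) {
  int32_t q = static_cast<int32_t>(std::nearbyint(x / scale)) + zero_point;
  return std::min(std::max(q, quant_min), quant_max);
}

int main() {
  // uint16 range carried in an int32 tensor, as the summary describes.
  const int32_t quant_min = 0;
  const int32_t quant_max = 65535;
  const double scale = 1.0 / 65535.0;
  for (float x : {0.0f, 0.5f, 1.0f, 2.0f}) {
    std::printf("%.3f -> %d\n", x,
                quantize_value(x, scale, /*zero_point=*/0, quant_min, quant_max));
  }
  // 2.0f overflows the window and is clamped to quant_max (65535).
}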
1 parent: fbbec00

File tree

2 files changed: +9 −1


kernels/quantized/cpu/op_dequantize.cpp

+3 −1
@@ -37,7 +37,9 @@ void check_dequantize_per_tensor_args(
     Tensor& out) {
   ET_CHECK_MSG(
       input.scalar_type() == ScalarType::Byte ||
-          input.scalar_type() == ScalarType::Char,
+          input.scalar_type() == ScalarType::Char ||
+          input.scalar_type() == ScalarType::Short ||
+          input.scalar_type() == ScalarType::Int,
       "input.scalar_type() %hdd is not supported:",
       input.scalar_type());
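On the dequantize side, the inverse mapping is x = (q − zero_point) * scale, and a ScalarType::Int input can now carry the full uint16 range. A minimal sketch, again assuming the standard affine scheme; dequantize_value is an illustrative helper, not the kernel's API.

#include <cstdint>
#include <cstdio>

// Standard affine dequantization: x = (q - zero_point) * scale.
float dequantize_value(int32_t q, double scale, int32_t zero_point) {
  return static_cast<float>((q - zero_point) * scale);
}

int main() {
  // An int32-stored uint16-range value round-trips back to float.
  const double scale = 1.0 / 65535.0;
  std::printf("%d -> %f\n", 0, dequantize_value(0, scale, /*zero_point=*/0));
  std::printf("%d -> %f\n", 65535, dequantize_value(65535, scale, /*zero_point=*/0));
}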

kernels/quantized/cpu/op_quantize.cpp

+6
@@ -56,6 +56,12 @@ void check_quantize_per_tensor_args(
         static_cast<int32_t>(std::numeric_limits<int8_t>::min());
     quant_max_upper_bound =
         static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+  } else if (dtype == ScalarType::Short) {
+    quant_min_lower_bound = std::numeric_limits<int16_t>::min();
+    quant_max_upper_bound = std::numeric_limits<int16_t>::max();
+  } else if (dtype == ScalarType::Int) {
+    quant_min_lower_bound = std::numeric_limits<int32_t>::min();
+    quant_max_upper_bound = std::numeric_limits<int32_t>::max();
   } else {
     ET_CHECK_MSG(false, "Unsupported dtype: %hdd", out_dtype);
   }
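The new branches widen the legal [quant_min, quant_max] window to the int16 and int32 numeric limits. A self-contained sketch of the same validation idea follows; check_bounds and Dtype are hypothetical stand-ins for the kernel's ET_CHECK_MSG-based checks, not the actual ExecuTorch code.

#include <cstdint>
#include <cstdio>
#include <limits>
#include <stdexcept>

enum class Dtype { Byte, Char, Short, Int };

// Reject quant_min/quant_max values that fall outside the numeric range of
// the requested storage dtype, mirroring check_quantize_per_tensor_args.
void check_bounds(Dtype dtype, int64_t quant_min, int64_t quant_max) {
  int64_t lo = 0;
  int64_t hi = 0;
  switch (dtype) {
    case Dtype::Byte:
      lo = std::numeric_limits<uint8_t>::min();
      hi = std::numeric_limits<uint8_t>::max();
      break;
    case Dtype::Char:
      lo = std::numeric_limits<int8_t>::min();
      hi = std::numeric_limits<int8_t>::max();
      break;
    case Dtype::Short:
      lo = std::numeric_limits<int16_t>::min();
      hi = std::numeric_limits<int16_t>::max();
      break;
    case Dtype::Int:
      lo = std::numeric_limits<int32_t>::min();
      hi = std::numeric_limits<int32_t>::max();
      break;
  }
  if (quant_min < lo || quant_max > hi) {
    throw std::runtime_error("quant_min/quant_max out of range for dtype");
  }
}

int main() {
  check_bounds(Dtype::Int, 0, 65535);  // uint16 range inside int32: passes
  try {
    check_bounds(Dtype::Short, 0, 65535);  // 65535 > int16 max: rejected
  } catch (const std::runtime_error& e) {
    std::printf("%s\n", e.what());
  }
}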
