Skip to content

Commit fc22673

Browse files
hthadicherla authored and kevalmorabia97 committed
Added exception for warning caused while creating int4 tensor (#461)
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent d639e74 commit fc22673

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

modelopt/onnx/quantization/gs_patching.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,14 @@ def _make_constant(
3939
converted_dtype = (
4040
dtype if isinstance(values, LazyValues) else onnx.helper.tensor_dtype_to_np_dtype(dtype)
4141
)
42-
if values.dtype != converted_dtype:
42+
43+
# Allow int8/uint8 as intermediate representation for INT4/UINT4
44+
# INT4/UINT4 values are stored as int8/uint8 in numpy arrays and packed during export
45+
is_valid_int4_intermediate = (dtype == onnx.TensorProto.INT4 and values.dtype == np.int8) or (
46+
dtype == onnx.TensorProto.UINT4 and values.dtype == np.uint8
47+
)
48+
49+
if not is_valid_int4_intermediate and values.dtype != converted_dtype:
4350
logger.error(
4451
f"Trying to create tensor with incompatible types: `{values.dtype}`, `{dtype}`"
4552
)

0 commit comments

Comments
 (0)