diff --git a/modelopt/onnx/quantization/gs_patching.py b/modelopt/onnx/quantization/gs_patching.py
index 8097d8a98..a0eea8495 100644
--- a/modelopt/onnx/quantization/gs_patching.py
+++ b/modelopt/onnx/quantization/gs_patching.py
@@ -39,7 +39,14 @@ def _make_constant(
     converted_dtype = (
         dtype if isinstance(values, LazyValues) else onnx.helper.tensor_dtype_to_np_dtype(dtype)
     )
-    if values.dtype != converted_dtype:
+
+    # Allow int8/uint8 as intermediate representation for INT4/UINT4
+    # INT4/UINT4 values are stored as int8/uint8 in numpy arrays and packed during export
+    is_valid_int4_intermediate = (dtype == onnx.TensorProto.INT4 and values.dtype == np.int8) or (
+        dtype == onnx.TensorProto.UINT4 and values.dtype == np.uint8
+    )
+
+    if not is_valid_int4_intermediate and values.dtype != converted_dtype:
         logger.error(
             f"Trying to create tensor with incompatible types: `{values.dtype}`, `{dtype}`"
         )
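
For context on the "packed during export" comment above, the following is a minimal, hypothetical sketch (not part of the patch and not the library's export routine) of how INT4 values held in an int8 numpy array can be packed two-per-byte at export time; the function name pack_int4 and the nibble ordering are illustrative assumptions only.

import numpy as np

def pack_int4(values: np.ndarray) -> np.ndarray:
    """Pack an int8 array of 4-bit values (range [-8, 7]) into bytes, two nibbles per byte.

    Illustrative only: actual ONNX export packing may differ in ordering and padding.
    """
    flat = values.astype(np.int8).flatten()
    if flat.size % 2:
        # Pad to an even count so every output byte holds exactly two nibbles.
        flat = np.append(flat, np.int8(0))
    low = (flat[0::2] & 0x0F).astype(np.uint8)          # even-indexed values -> low nibble
    high = (flat[1::2] & 0x0F).astype(np.uint8) << 4    # odd-indexed values -> high nibble
    return low | high

# Four INT4 values stored in an int8 array pack into two bytes.
packed = pack_int4(np.array([1, -2, 7, -8], dtype=np.int8))
print(packed)  # [225 135], i.e. 0xE1 and 0x87

Until that packing step runs, the constant's backing array legitimately has dtype int8/uint8, which is why the patch treats that combination as compatible with TensorProto.INT4/UINT4 instead of logging an error.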