Skip to content

Commit 072e837

Browse files
committed
check data type using onnx tensors
Signed-off-by: Purushothaman Saravanan <[email protected]>
1 parent 51a45dc commit 072e837

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def _get_scale_and_zp(
541541
node: onnx.NodeProto,
542542
initializers: dict[str, onnx.TensorProto],
543543
tensor_producers: dict[str, onnx.NodeProto],
544-
) -> tuple[np.ndarray, np.ndarray]:
544+
) -> tuple[onnx.TensorProto, onnx.TensorProto]:
545545
"""Get scale and zero point tensors for a node.
546546
547547
Args:
@@ -550,7 +550,7 @@ def _get_scale_and_zp(
550550
tensor_producers: Dictionary of tensor producers
551551
552552
Returns:
553-
Tuple of (scale_array, zero_point_array)
553+
Tuple of (scale_tensor, zero_point_tensor)
554554
555555
Raises:
556556
ValueError: If scale or zero point cannot be found
@@ -564,7 +564,6 @@ def _get_scale_and_zp(
564564
if not producer or not producer.attribute:
565565
raise ValueError(f"Invalid scale producer for {scale_name}")
566566
scale = producer.attribute[0].t
567-
scale_array = onnx.numpy_helper.to_array(scale)
568567

569568
# Get zero point tensor
570569
zp_name = node.input[2]
@@ -575,9 +574,8 @@ def _get_scale_and_zp(
575574
if not producer or not producer.attribute:
576575
raise ValueError(f"Invalid zero point producer for {zp_name}")
577576
zp = producer.attribute[0].t
578-
zp_array = onnx.numpy_helper.to_array(zp)
579577

580-
return scale_array, zp_array
578+
return scale, zp
581579

582580

583581
def _get_successive_consumers(
@@ -615,16 +613,16 @@ def _get_successive_consumers(
615613

616614
def _convert_weight(
617615
weight_array: np.ndarray,
618-
scale_array: np.ndarray,
619-
zp_array: np.ndarray,
616+
scale: onnx.TensorProto,
617+
zp: onnx.TensorProto,
620618
quantized_node: onnx.NodeProto,
621619
) -> np.ndarray:
622620
"""Convert a weight tensor to INT8/FP8 format based on scale and zero point.
623621
624622
Args:
625623
weight_array: The weight tensor to convert
626-
scale_array: The scale tensor for quantization
627-
zp_array: The zero point tensor for quantization
624+
scale: The scale tensor for quantization
625+
zp: The zero point tensor for quantization
628626
quantized_node: The operation node that will use the converted weight
629627
630628
Returns:
@@ -641,6 +639,10 @@ def _convert_weight(
641639
weight_shape = weight_array.shape
642640
op_type = quantized_node.op_type
643641

642+
# Convert onnx tensors to numpy array
643+
scale_array = onnx.numpy_helper.to_array(scale)
644+
zp_array = onnx.numpy_helper.to_array(zp)
645+
644646
# Dynamically determine transB for Gemm
645647
trans_b = 0
646648
if op_type == "Gemm":
@@ -672,7 +674,7 @@ def _convert_weight(
672674
zp_array = zp_array.reshape(*reshape_dims)
673675

674676
# Convert to INT8/FP8
675-
if zp_array.dtype == onnx_dtype_map["Float8"]:
677+
if zp.data_type == onnx_dtype_map["Float8"]:
676678
scaled = np.asarray(weight_array / scale_array) + zp_array
677679
else:
678680
scaled = np.asarray((weight_array / scale_array).round())
@@ -713,7 +715,9 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
713715
def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
714716
"""Create a FLOAT8E4M3FN tensor directly from numpy array."""
715717
fp8_data = _cast_fp8(scaled)
716-
return onnx.numpy_helper.from_array(fp8_data, weight_name)
718+
tensor = onnx.numpy_helper.from_array(fp8_data, weight_name)
719+
tensor.data_type = onnx_dtype_map["Float8"]
720+
return tensor
717721

718722

719723
def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
@@ -765,16 +769,16 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
765769
weight_array = onnx.numpy_helper.to_array(weight)
766770

767771
# Get scale and zero point
768-
scale_array, zp_array = _get_scale_and_zp(node, initializers, tensor_producers)
772+
scale, zp = _get_scale_and_zp(node, initializers, tensor_producers)
769773

770774
# Validate Q->DQ->Op pattern and get consumers
771775
dq_node, quantized_node = _get_successive_consumers(node, tensor_consumers)
772776

773777
# Convert weight
774-
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
778+
scaled = _convert_weight(weight_array, scale, zp, quantized_node)
775779

776780
# Create and update new weight tensor
777-
if zp_array.dtype == onnx_dtype_map["Float8"]:
781+
if zp.data_type == onnx_dtype_map["Float8"]:
778782
new_weight = _create_fp8_tensor(scaled, weight_name)
779783
logger.debug(f"Converted {weight_name} to FP8")
780784
else:

0 commit comments

Comments
 (0)