
Commit c850347

check data type using onnx tensors

Signed-off-by: Purushothaman Saravanan <[email protected]>

1 parent: 75ba7cd

1 file changed: modelopt/onnx/quantization/qdq_utils.py (18 additions, 14 deletions)
@@ -537,7 +537,7 @@ def _get_scale_and_zp(
     node: onnx.NodeProto,
     initializers: dict[str, onnx.TensorProto],
     tensor_producers: dict[str, onnx.NodeProto],
-) -> tuple[np.ndarray, np.ndarray]:
+) -> tuple[onnx.TensorProto, onnx.TensorProto]:
     """Get scale and zero point tensors for a node.
 
     Args:
@@ -546,7 +546,7 @@ def _get_scale_and_zp(
         tensor_producers: Dictionary of tensor producers
 
     Returns:
-        Tuple of (scale_array, zero_point_array)
+        Tuple of (scale_tensor, zero_point_tensor)
 
     Raises:
         ValueError: If scale or zero point cannot be found
@@ -560,7 +560,6 @@ def _get_scale_and_zp(
         if not producer or not producer.attribute:
             raise ValueError(f"Invalid scale producer for {scale_name}")
         scale = producer.attribute[0].t
-    scale_array = onnx.numpy_helper.to_array(scale)
 
     # Get zero point tensor
     zp_name = node.input[2]
@@ -571,9 +570,8 @@ def _get_scale_and_zp(
         if not producer or not producer.attribute:
             raise ValueError(f"Invalid zero point producer for {zp_name}")
         zp = producer.attribute[0].t
-    zp_array = onnx.numpy_helper.to_array(zp)
 
-    return scale_array, zp_array
+    return scale, zp
 
 
 def _get_successive_consumers(
@@ -611,16 +609,16 @@ def _get_successive_consumers(
 
 def _convert_weight(
     weight_array: np.ndarray,
-    scale_array: np.ndarray,
-    zp_array: np.ndarray,
+    scale: onnx.TensorProto,
+    zp: onnx.TensorProto,
     quantized_node: onnx.NodeProto,
 ) -> np.ndarray:
     """Convert a weight tensor to INT8/FP8 format based on scale and zero point.
 
     Args:
         weight_array: The weight tensor to convert
-        scale_array: The scale tensor for quantization
-        zp_array: The zero point tensor for quantization
+        scale: The scale tensor for quantization
+        zp: The zero point tensor for quantization
         quantized_node: The operation node that will use the converted weight
 
     Returns:
@@ -637,6 +635,10 @@ def _convert_weight(
     weight_shape = weight_array.shape
     op_type = quantized_node.op_type
 
+    # Convert onnx tensors to numpy array
+    scale_array = onnx.numpy_helper.to_array(scale)
+    zp_array = onnx.numpy_helper.to_array(zp)
+
     # Dynamically determine transB for Gemm
     trans_b = 0
     if op_type == "Gemm":
@@ -668,7 +670,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)
 
     # Convert to INT8/FP8
-    if zp_array.dtype == onnx_dtype_map["Float8"]:
+    if zp.data_type == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
@@ -709,7 +711,9 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
 def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
     """Create a FLOAT8E4M3FN tensor directly from numpy array."""
     fp8_data = _cast_fp8(scaled)
-    return onnx.numpy_helper.from_array(fp8_data, weight_name)
+    tensor = onnx.numpy_helper.from_array(fp8_data, weight_name)
+    tensor.data_type = onnx_dtype_map["Float8"]
+    return tensor
 
 
 def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
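
Note on the `_create_fp8_tensor` hunk above: `onnx.numpy_helper.from_array` infers the element type from the numpy dtype of `fp8_data`, so the commit re-tags the resulting proto as FLOAT8E4M3FN afterwards. A minimal sketch of that re-tagging pattern, using a plain uint8 buffer as a hypothetical stand-in for whatever `_cast_fp8` returns:

    import numpy as np
    from onnx import TensorProto, numpy_helper

    # Hypothetical stand-in for _cast_fp8 output: FP8 bytes packed in a uint8 buffer.
    fp8_payload = np.zeros(4, dtype=np.uint8)

    tensor = numpy_helper.from_array(fp8_payload, "weight")
    print(tensor.data_type)                      # 2 (UINT8), inferred from the numpy dtype
    tensor.data_type = TensorProto.FLOAT8E4M3FN  # re-tag the same raw bytes as FP8
    print(tensor.data_type)                      # 17 (FLOAT8E4M3FN)
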
@@ -761,16 +765,16 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         weight_array = onnx.numpy_helper.to_array(weight)
 
         # Get scale and zero point
-        scale_array, zp_array = _get_scale_and_zp(node, initializers, tensor_producers)
+        scale, zp = _get_scale_and_zp(node, initializers, tensor_producers)
 
         # Validate Q->DQ->Op pattern and get consumers
         dq_node, quantized_node = _get_successive_consumers(node, tensor_consumers)
 
         # Convert weight
         scaled = _convert_weight(weight_array, scale, zp, quantized_node)
 
         # Create and update new weight tensor
-        if zp_array.dtype == onnx_dtype_map["Float8"]:
+        if zp.data_type == onnx_dtype_map["Float8"]:
             new_weight = _create_fp8_tensor(scaled, weight_name)
             logger.debug(f"Converted {weight_name} to FP8")
         else:
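
Taken together, the hunks above keep scale and zero point as `onnx.TensorProto` objects all the way into `_convert_weight`, converting to numpy only where the arithmetic needs it, and the FP8 branch is now selected from the proto's `data_type` enum rather than from the numpy view of the data. A minimal standalone sketch of that check, using the stock `onnx.TensorProto` enum directly (`onnx_dtype_map` is modelopt's internal lookup table and is not reproduced here):

    from onnx import TensorProto

    # Build a tiny FP8 zero-point tensor by hand (hypothetical example data).
    zp = TensorProto()
    zp.name = "zp"
    zp.data_type = TensorProto.FLOAT8E4M3FN
    zp.dims.extend([1])
    zp.raw_data = bytes([0x00])  # one FP8 zero

    # The element type lives on the proto itself, so the check does not depend
    # on which numpy dtype numpy_helper.to_array() would hand back for FP8 data.
    if zp.data_type == TensorProto.FLOAT8E4M3FN:
        print("FP8 zero point -> take the float8 conversion path")
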
