@@ -541,7 +541,7 @@ def _get_scale_and_zp(
541541 node : onnx .NodeProto ,
542542 initializers : dict [str , onnx .TensorProto ],
543543 tensor_producers : dict [str , onnx .NodeProto ],
544- ) -> tuple [np . ndarray , np . ndarray ]:
544+ ) -> tuple [onnx . TensorProto , onnx . TensorProto ]:
545545 """Get scale and zero point tensors for a node.
546546
547547 Args:
@@ -550,7 +550,7 @@ def _get_scale_and_zp(
550550 tensor_producers: Dictionary of tensor producers
551551
552552 Returns:
553- Tuple of (scale_array, zero_point_array )
553+ Tuple of (scale_tensor, zero_point_tensor )
554554
555555 Raises:
556556 ValueError: If scale or zero point cannot be found
@@ -564,7 +564,6 @@ def _get_scale_and_zp(
564564 if not producer or not producer .attribute :
565565 raise ValueError (f"Invalid scale producer for { scale_name } " )
566566 scale = producer .attribute [0 ].t
567- scale_array = onnx .numpy_helper .to_array (scale )
568567
569568 # Get zero point tensor
570569 zp_name = node .input [2 ]
@@ -575,9 +574,8 @@ def _get_scale_and_zp(
575574 if not producer or not producer .attribute :
576575 raise ValueError (f"Invalid zero point producer for { zp_name } " )
577576 zp = producer .attribute [0 ].t
578- zp_array = onnx .numpy_helper .to_array (zp )
579577
580- return scale_array , zp_array
578+ return scale , zp
581579
582580
583581def _get_successive_consumers (
@@ -615,16 +613,16 @@ def _get_successive_consumers(
615613
616614def _convert_weight (
617615 weight_array : np .ndarray ,
618- scale_array : np . ndarray ,
619- zp_array : np . ndarray ,
616+ scale : onnx . TensorProto ,
617+ zp : onnx . TensorProto ,
620618 quantized_node : onnx .NodeProto ,
621619) -> np .ndarray :
622620 """Convert a weight tensor to INT8/FP8 format based on scale and zero point.
623621
624622 Args:
625623 weight_array: The weight tensor to convert
626- scale_array : The scale tensor for quantization
627- zp_array : The zero point tensor for quantization
624+ scale : The scale tensor for quantization
625+ zp : The zero point tensor for quantization
628626 quantized_node: The operation node that will use the converted weight
629627
630628 Returns:
@@ -641,6 +639,10 @@ def _convert_weight(
641639 weight_shape = weight_array .shape
642640 op_type = quantized_node .op_type
643641
642+ # Convert onnx tensors to numpy array
643+ scale_array = onnx .numpy_helper .to_array (scale )
644+ zp_array = onnx .numpy_helper .to_array (zp )
645+
644646 # Dynamically determine transB for Gemm
645647 trans_b = 0
646648 if op_type == "Gemm" :
@@ -672,7 +674,7 @@ def _convert_weight(
672674 zp_array = zp_array .reshape (* reshape_dims )
673675
674676 # Convert to INT8/FP8
675- if zp_array . dtype == onnx_dtype_map ["Float8" ]:
677+ if zp . data_type == onnx_dtype_map ["Float8" ]:
676678 scaled = np .asarray (weight_array / scale_array ) + zp_array
677679 else :
678680 scaled = np .asarray ((weight_array / scale_array ).round ())
@@ -713,7 +715,9 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
713715def _create_fp8_tensor (scaled : np .ndarray , weight_name : str ) -> onnx .TensorProto :
714716 """Create a FLOAT8E4M3FN tensor directly from numpy array."""
715717 fp8_data = _cast_fp8 (scaled )
716- return onnx .numpy_helper .from_array (fp8_data , weight_name )
718+ tensor = onnx .numpy_helper .from_array (fp8_data , weight_name )
719+ tensor .data_type = onnx_dtype_map ["Float8" ]
720+ return tensor
717721
718722
719723def qdq_to_dq (onnx_model : onnx .ModelProto ) -> onnx .ModelProto :
@@ -765,16 +769,16 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
765769 weight_array = onnx .numpy_helper .to_array (weight )
766770
767771 # Get scale and zero point
768- scale_array , zp_array = _get_scale_and_zp (node , initializers , tensor_producers )
772+ scale , zp = _get_scale_and_zp (node , initializers , tensor_producers )
769773
770774 # Validate Q->DQ->Op pattern and get consumers
771775 dq_node , quantized_node = _get_successive_consumers (node , tensor_consumers )
772776
773777 # Convert weight
774- scaled = _convert_weight (weight_array , scale_array , zp_array , quantized_node )
778+ scaled = _convert_weight (weight_array , scale , zp , quantized_node )
775779
776780 # Create and update new weight tensor
777- if zp_array . dtype == onnx_dtype_map ["Float8" ]:
781+ if zp . data_type == onnx_dtype_map ["Float8" ]:
778782 new_weight = _create_fp8_tensor (scaled , weight_name )
779783 logger .debug (f"Converted { weight_name } to FP8" )
780784 else :
0 commit comments