@@ -537,7 +537,7 @@ def _get_scale_and_zp(
537 537      node: onnx.NodeProto,
538 538      initializers: dict[str, onnx.TensorProto],
539 539      tensor_producers: dict[str, onnx.NodeProto],
540     -) -> tuple[np.ndarray, np.ndarray]:
    540 +) -> tuple[onnx.TensorProto, onnx.TensorProto]:
541 541      """Get scale and zero point tensors for a node.
542 542  
543 543      Args:
@@ -546,7 +546,7 @@ def _get_scale_and_zp(
546 546          tensor_producers: Dictionary of tensor producers
547 547  
548 548      Returns:
549     -        Tuple of (scale_array, zero_point_array)
    549 +        Tuple of (scale_tensor, zero_point_tensor)
550 550  
551 551      Raises:
552 552          ValueError: If scale or zero point cannot be found
@@ -560,7 +560,6 @@ def _get_scale_and_zp(
560 560      if not producer or not producer.attribute:
561 561          raise ValueError(f"Invalid scale producer for {scale_name}")
562 562      scale = producer.attribute[0].t
563     -    scale_array = onnx.numpy_helper.to_array(scale)
564 563  
565 564      # Get zero point tensor
566 565      zp_name = node.input[2]
@@ -571,9 +570,8 @@ def _get_scale_and_zp(
571 570      if not producer or not producer.attribute:
572 571          raise ValueError(f"Invalid zero point producer for {zp_name}")
573 572      zp = producer.attribute[0].t
574     -    zp_array = onnx.numpy_helper.to_array(zp)
575 573  
576     -    return scale_array, zp_array
    574 +    return scale, zp
577 575  
578 576  
579 577  def _get_successive_consumers(
@@ -611,16 +609,16 @@ def _get_successive_consumers(
611 609  
612 610  def _convert_weight(
613 611      weight_array: np.ndarray,
614     -    scale_array: np.ndarray,
615     -    zp_array: np.ndarray,
    612 +    scale: onnx.TensorProto,
    613 +    zp: onnx.TensorProto,
616 614      quantized_node: onnx.NodeProto,
617 615  ) -> np.ndarray:
618 616      """Convert a weight tensor to INT8/FP8 format based on scale and zero point.
619 617  
620 618      Args:
621 619          weight_array: The weight tensor to convert
622     -        scale_array: The scale tensor for quantization
623     -        zp_array: The zero point tensor for quantization
    620 +        scale: The scale tensor for quantization
    621 +        zp: The zero point tensor for quantization
624 622          quantized_node: The operation node that will use the converted weight
625 623  
626 624      Returns:
@@ -637,6 +635,10 @@ def _convert_weight(
637 635      weight_shape = weight_array.shape
638 636      op_type = quantized_node.op_type
639 637  
    638 +    # Convert onnx tensors to numpy array
    639 +    scale_array = onnx.numpy_helper.to_array(scale)
    640 +    zp_array = onnx.numpy_helper.to_array(zp)
    641 +
640 642      # Dynamically determine transB for Gemm
641 643      trans_b = 0
642 644      if op_type == "Gemm":
@@ -668,7 +670,7 @@ def _convert_weight(
668 670          zp_array = zp_array.reshape(*reshape_dims)
669 671  
670 672      # Convert to INT8/FP8
671     -    if zp_array.dtype == onnx_dtype_map["Float8"]:
    673 +    if zp.data_type == onnx_dtype_map["Float8"]:
672 674          scaled = np.asarray(weight_array / scale_array) + zp_array
673 675      else:
674 676          scaled = np.asarray((weight_array / scale_array).round())
@@ -709,7 +711,9 @@ def _cast_fp4(array: np.ndarray) -> np.ndarray:
709 711  def _create_fp8_tensor(scaled: np.ndarray, weight_name: str) -> onnx.TensorProto:
710 712      """Create a FLOAT8E4M3FN tensor directly from numpy array."""
711 713      fp8_data = _cast_fp8(scaled)
712     -    return onnx.numpy_helper.from_array(fp8_data, weight_name)
    714 +    tensor = onnx.numpy_helper.from_array(fp8_data, weight_name)
    715 +    tensor.data_type = onnx_dtype_map["Float8"]
    716 +    return tensor
713 717  
714 718  
715 719  def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
@@ -761,16 +765,16 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
761 765              weight_array = onnx.numpy_helper.to_array(weight)
762 766  
763 767              # Get scale and zero point
764     -            scale_array, zp_array = _get_scale_and_zp(node, initializers, tensor_producers)
    768 +            scale, zp = _get_scale_and_zp(node, initializers, tensor_producers)
765 769  
766 770              # Validate Q->DQ->Op pattern and get consumers
767 771              dq_node, quantized_node = _get_successive_consumers(node, tensor_consumers)
768 772  
769 773              # Convert weight
770     -            scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
    774 +            scaled = _convert_weight(weight_array, scale, zp, quantized_node)
771 775  
772 776              # Create and update new weight tensor
773     -            if zp_array.dtype == onnx_dtype_map["Float8"]:
    777 +            if zp.data_type == onnx_dtype_map["Float8"]:
774 778                  new_weight = _create_fp8_tensor(scaled, weight_name)
775 779                  logger.debug(f"Converted {weight_name} to FP8")
776 780              else:
0 commit comments