diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index b436af1b2..64e38f44e 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -184,6 +184,17 @@ def convert(
         logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
         logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")
 
+        # Since we have removed all casts, we can pre-compute the tensor_to_consumers and
+        # tensor_to_producers maps since they will not change for the duration of the conversion.
+        tensor_to_consumers = defaultdict(list)
+        tensor_to_producers = defaultdict(list)
+
+        for node in self.model.graph.node:
+            for input in node.input:
+                tensor_to_consumers[input].append(node)
+            for output in node.output:
+                tensor_to_producers[output].append(node)
+
         # Add cast nodes for "cast_up" tensors
         for tensor_name in cast_up_tensors:
             exclude_consumers = low_precision_nodes
@@ -194,7 +205,11 @@ def convert(
                     set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
                 )
             self._add_cast(
-                tensor_name, self.high_precision_type, exclude_consumers=exclude_consumers
+                tensor_name,
+                self.high_precision_type,
+                exclude_consumers=exclude_consumers,
+                tensor_to_consumers=tensor_to_consumers,
+                tensor_to_producers=tensor_to_producers,
             )
 
         # Add cast nodes for "cast_down" tensors
@@ -203,6 +218,8 @@ def convert(
                 tensor_name,
                 self.low_precision_type,
                 exclude_consumers=high_precision_nodes,
+                tensor_to_consumers=tensor_to_consumers,
+                tensor_to_producers=tensor_to_producers,
             )
 
         # Convert initializers to correct precision according to the consumer nodes
@@ -803,7 +820,12 @@ def _remove_preexisting_casts(self) -> None:
             self.model.graph.node.remove(node)
 
     def _add_cast(
-        self, tensor_name: str, cast_to: PrecisionTypes, exclude_consumers: list[str] = []
+        self,
+        tensor_name: str,
+        cast_to: PrecisionTypes,
+        exclude_consumers: list[str] = [],
+        tensor_to_consumers: dict[str, list[onnx.NodeProto]] | None = None,
+        tensor_to_producers: dict[str, list[onnx.NodeProto]] | None = None,
     ) -> None:
         """Adds a cast operation on a tensor and reconnects its consumers.
 
@@ -811,6 +833,14 @@ def _add_cast(
             tensor_name: Name of the tensor to cast.
             cast_to: Target precision type to cast to.
             exclude_consumers: List of consumer nodes to exclude from reconnection.
+            tensor_to_consumers: Optional pre-computed map of tensor names to their consumer nodes.
+                If not provided, the map will be computed on the fly.
+            tensor_to_producers: Optional pre-computed map of tensor names to their producer nodes.
+                If not provided, the map will be computed on the fly.
+
+        NOTE: It is up to the user to ensure that the tensor_to_consumers and tensor_to_producers
+        maps are up to date before calling this function. Consecutive casts in the graph will break
+        this assumption and the maps must be recomputed.
         """
         # Empty tensors may have special handling in ONNX (e.g. for Resize scales) which can break when redundant casts
         # are injected. Since there's no data, it's safe to only update the metadata.
@@ -848,7 +878,10 @@ def _add_cast(
             name=f"{tensor_name}_cast_to_{cast_to.str_short}",
         )
 
-        consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
+        if tensor_to_consumers is None:
+            consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
+        else:
+            consumer_nodes = tensor_to_consumers.get(tensor_name, [])
         consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers]
         for node in consumer_nodes:
             for i, input_name in enumerate(node.input):
@@ -868,7 +901,10 @@ def _add_cast(
                     break
 
         # Find producer node to insert cast after it
-        producer_nodes = utils.get_producer_nodes(self.model, tensor_name)
+        if tensor_to_producers is None:
+            producer_nodes = utils.get_producer_nodes(self.model, tensor_name)
+        else:
+            producer_nodes = tensor_to_producers.get(tensor_name, [])
         if producer_nodes:
             # Insert after the producer node
             # Find index by iterating since RepeatedCompositeContainer doesn't support index()