Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,15 @@ def convert(
logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")

tensor_to_consumers = defaultdict(list)
tensor_to_producers = defaultdict(list)

for node in self.model.graph.node:
for input in node.input:
tensor_to_consumers[input].append(node)
for output in node.output:
tensor_to_producers[output].append(node)
Comment on lines +189 to +196
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is indeed more efficient, but it's a bit risky. It relies on the assumption that the graph is static, meaning the list of consumers and producers for each tensor is constant. However, when we inject cast nodes, we render this data invalid.

In the limited scope where this is applied, it's probably OK, because we call self._remove_preexisting_casts() which prevents "chains" of cast nodes (cast->cast->cast). But we cannot replace all calls to utils.get_consumer_nodes and utils.get_producer_nodes, nor can we modify all _add_cast instances (which you probably know, because you avoided that in this PR).
I think this at least warrants a comment to warn unsuspecting developers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, we can only do this optimization since there are no cast nodes in the graph and the producers/consumers are guaranteed not to be affected from iteration to iteration, and it's why I made sure to keep the params optional for this specific case.

It's up to you how you want to proceed, we can:

  1. Keep as is, I added a warning to other devs, I feel that devs using this function can just leave those optional params empty to keep the safe behavior.
  2. We write a separate function or move this logic out of _add_cast to make it super explicit for this use case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the comments. Approved.


# Add cast nodes for "cast_up" tensors
for tensor_name in cast_up_tensors:
exclude_consumers = low_precision_nodes
Expand All @@ -194,7 +203,11 @@ def convert(
set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
)
self._add_cast(
tensor_name, self.high_precision_type, exclude_consumers=exclude_consumers
tensor_name,
self.high_precision_type,
exclude_consumers=exclude_consumers,
tensor_to_consumers=tensor_to_consumers,
tensor_to_producers=tensor_to_producers,
)

# Add cast nodes for "cast_down" tensors
Expand All @@ -203,6 +216,8 @@ def convert(
tensor_name,
self.low_precision_type,
exclude_consumers=high_precision_nodes,
tensor_to_consumers=tensor_to_consumers,
tensor_to_producers=tensor_to_producers,
)

# Convert initializers to correct precision according to the consumer nodes
Expand Down Expand Up @@ -803,14 +818,23 @@ def _remove_preexisting_casts(self) -> None:
self.model.graph.node.remove(node)

def _add_cast(
self, tensor_name: str, cast_to: PrecisionTypes, exclude_consumers: list[str] = []
self,
tensor_name: str,
cast_to: PrecisionTypes,
exclude_consumers: list[str] = [],
tensor_to_consumers: dict[str, list[onnx.NodeProto]] | None = None,
tensor_to_producers: dict[str, list[onnx.NodeProto]] | None = None,
) -> None:
"""Adds a cast operation on a tensor and reconnects its consumers.

Args:
tensor_name: Name of the tensor to cast.
cast_to: Target precision type to cast to.
exclude_consumers: List of consumer nodes to exclude from reconnection.
tensor_to_consumers: Optional pre-computed map of tensor names to their consumer nodes.
If not provided, the map will be computed on the fly.
tensor_to_producers: Optional pre-computed map of tensor names to their producer nodes.
If not provided, the map will be computed on the fly.
"""
# Empty tensors may have special handling in ONNX (e.g. for Resize scales) which can break when redundant casts
# are injected. Since there's no data, it's safe to only update the metadata.
Expand Down Expand Up @@ -848,7 +872,10 @@ def _add_cast(
name=f"{tensor_name}_cast_to_{cast_to.str_short}",
)

consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
if tensor_to_consumers is None:
consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
else:
consumer_nodes = tensor_to_consumers.get(tensor_name, [])
consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers]
for node in consumer_nodes:
for i, input_name in enumerate(node.input):
Expand All @@ -868,7 +895,10 @@ def _add_cast(
break

# Find producer node to insert cast after it
producer_nodes = utils.get_producer_nodes(self.model, tensor_name)
if tensor_to_producers is None:
producer_nodes = utils.get_producer_nodes(self.model, tensor_name)
else:
producer_nodes = tensor_to_producers.get(tensor_name, [])
if producer_nodes:
# Insert after the producer node
# Find index by iterating since RepeatedCompositeContainer doesn't support index()
Expand Down