19 changes: 18 additions & 1 deletion onnxscript/optimizer/_constant_folding.py
@@ -974,6 +974,8 @@ def __init__(
self._sizes: dict[str, int] = {}
self._modified: bool = False
self._state = OptimizerState()
# Count of unknown (None) symbolic dimensions seen so far for generating unique names
self._unknown_dim_count = 0
self._reset()

def _reset(self) -> None:
@@ -982,6 +984,7 @@ def _reset(self) -> None:
self._sizes = {}
self._modified = False
self._state = OptimizerState()
self._unknown_dim_count = 0

def _do_inference(self, node: ir.Node) -> None:
output_types = {}
@@ -1029,7 +1032,15 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
inferred_shape = ir.serde.deserialize_type_proto_for_shape(
inferred_type
)
- output.shape = _merge_shapes(output.shape, inferred_shape)
+ merged_shape = _merge_shapes(output.shape, inferred_shape)
Contributor:

Do you mean _merge_shapes propagating sym_dim? It looks more like filling in shape info if it's a known int.

Collaborator Author:

I think so:

    if dim1.value is None:
        return dim2

Contributor:

But in this specific case, I don't see how dim2 can provide a meaningful sym_dim, as it comes from ONNX type inference?

Contributor:

I might be wrong.

Collaborator Author (@justinchuby, Oct 10, 2025):

I thought dim2 was the original output dim, which may be from PyTorch?

Collaborator Author (@justinchuby, Oct 10, 2025):

Oh, dim1 is. Maybe dim1 should be the preferred shape then?

Collaborator Author:

Yeah, you must be right in the dim-merging case. However: if output.shape is None, then we can take inferred_shape; if output.shape is not None, we will keep it. We probably call _merge_shapes here for robust logic that is shared.

Collaborator:

I think the idea of merging shape information coming from two sources (assuming both to be correct) is a useful one. There is no point in specializing any further with assumptions about which of the two sources will have what information, because it is not needed. In particular, we can't (and need not) assume that the existing output.shape comes from the PyTorch exporter; it might have been introduced by some optimization rule.

The only special case to handle is when two different symbolic dims are used for the same dim in the two shapes. For now, we choose the first one. (Ideally, the underlying system would record that the two symbolic dims are meant to be the same, so that one could be used globally in place of the other. We don't do such things at this point.)
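A minimal sketch of the dim-merging rule described above, assuming the onnxscript.ir types used in the diff; merge_dims and its signature are hypothetical, not the actual _merge_shapes helper:

```python
from onnxscript import ir


def merge_dims(
    dim1: int | ir.SymbolicDim, dim2: int | ir.SymbolicDim
) -> int | ir.SymbolicDim:
    """Prefer whichever dim carries information; on a conflict, keep the first."""
    if isinstance(dim1, ir.SymbolicDim) and dim1.value is None:
        # dim1 is an unknown dim, so dim2 cannot be less informative.
        return dim2
    # dim1 is a concrete int or a named symbolic dim: keep it, even if dim2
    # uses a different symbolic name for the same dimension.
    return dim1
```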


# Replace unknown dims with uniquely named symbolic dims
assert merged_shape is not None
for i in range(len(merged_shape)):
    if merged_shape.is_unknown_dim(i):
        merged_shape[i] = ir.SymbolicDim(self._new_unknown_dim_name())

output.shape = merged_shape
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
except Exception as e:
logger.debug(
@@ -1038,6 +1049,12 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
e,
)

def _new_unknown_dim_name(self) -> str:
"""Generate a new unique name for an unknown (None) symbolic dimension."""
name = f"unknown_{self._unknown_dim_count}"
self._unknown_dim_count += 1
Collaborator:

The basic idea is useful. But this doesn't guarantee that the generated dim name will be unique (even though the likelihood of a conflict is low right now). Technically, we would need to identify all symbolic names used in the model first and check for conflicts (e.g., like onnx does here).
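A rough sketch of that idea, assuming the ir.Graph and ir.Shape APIs already used in this file; the helper names are hypothetical:

```python
from onnxscript import ir


def used_symbolic_dim_names(graph: ir.Graph) -> set[str]:
    """Collect every symbolic dim name appearing on graph inputs or node outputs."""
    names: set[str] = set()
    values = list(graph.inputs)
    for node in graph:
        values.extend(node.outputs)
    for value in values:
        shape = value.shape
        if shape is None:
            continue
        for i in range(len(shape)):
            dim = shape[i]
            if isinstance(dim, ir.SymbolicDim) and dim.value is not None:
                names.add(dim.value)
    return names


def fresh_unknown_dim_name(used: set[str], start: int = 0) -> str:
    """Return the first 'unknown_<n>' name that does not conflict with existing names."""
    counter = start
    while f"unknown_{counter}" in used:
        counter += 1
    return f"unknown_{counter}"
```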

Collaborator Author:

I can fix that part. Do you have an idea why this would break SkipLayerNormFusion?

Collaborator:

> I can fix that part. Do you have an idea why this would break SkipLayerNormFusion?

Will take a look.

Collaborator:

Strange ... was this working earlier? I noticed failures due to lack of shape information. That makes sense, since the input models don't have shape information. Specifically, these test cases were generated by converting ONNX models into onnxscript models, but the original examples did not include the shape information from the original models in the onnxscript models (as here). Later on, the examples were extended so that the value-infos from the ONNX models were also stored in the corresponding onnxscript test case (like here) ... but it looks like we still have some older test cases without shape information.

I added a call to the shape-inference pass at the beginning of the test case ... that mostly works, but it still fails in some edge cases where it looks like we are unable to infer that some dim is "batch-size".

I am guessing that if we update the test cases to include shape information (as produced by exporters), it should work ... need to think about the best workaround.
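For reference, one way to populate shape information on such a test model before running the optimizer; the file name is a placeholder and the actual test harness may use a different shape-inference pass:

```python
import onnx
import onnx.shape_inference

from onnxscript import ir

# Hypothetical test model that lacks value_info / shape annotations.
model_proto = onnx.load("test_model.onnx")
# Run ONNX shape inference to fill in shapes before optimizing.
model_proto = onnx.shape_inference.infer_shapes(model_proto, data_prop=True)
model = ir.serde.deserialize_model(model_proto)
```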

return name

def new_constant(self, node: ir.Node, value) -> ir.Node | None:
irvalue = node.outputs[0]
if not isinstance(value, np.ndarray):