diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 9a740c783..853436fce 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -974,6 +974,8 @@ def __init__( self._sizes: dict[str, int] = {} self._modified: bool = False self._state = OptimizerState() + # Count of unknown (None) symbolic dimensions seen so far for generating unique names + self._unknown_dim_count = 0 self._reset() def _reset(self) -> None: @@ -982,6 +984,7 @@ def _reset(self) -> None: self._sizes = {} self._modified = False self._state = OptimizerState() + self._unknown_dim_count = 0 def _do_inference(self, node: ir.Node) -> None: output_types = {} @@ -1029,7 +1032,15 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: inferred_shape = ir.serde.deserialize_type_proto_for_shape( inferred_type ) - output.shape = _merge_shapes(output.shape, inferred_shape) + merged_shape = _merge_shapes(output.shape, inferred_shape) + + # Replace unknown dims with uniquely named symbolic dims + assert merged_shape is not None + for i in range(len(merged_shape)): + if merged_shape.is_unknown_dim(i): + merged_shape[i] = ir.SymbolicDim(self._new_unknown_dim_name()) + + output.shape = merged_shape output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: logger.debug( @@ -1038,6 +1049,12 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: e, ) + def _new_unknown_dim_name(self) -> str: + """Generate a new unique name for an unknown (None) symbolic dimension.""" + name = f"unknown_{self._unknown_dim_count}" + self._unknown_dim_count += 1 + return name + def new_constant(self, node: ir.Node, value) -> ir.Node | None: irvalue = node.outputs[0] if not isinstance(value, np.ndarray):