diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py index ec0766c36..24f4b1339 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker_query.py +++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py @@ -724,9 +724,12 @@ def update_conditions_( update_conditions_(conditions, input_name, is_variadic, is_constant, shape, None) conditions[f"{input_name}_is_none"] = False + conditions["n_outputs"] = len(node.output) + # Try to derive properties, but catch errors for incomplete/invalid model information try: conditions = runtime_checker_op.derive_properties(conditions) + conditions.pop("n_outputs", None) except (KeyError, TypeError, IndexError) as e: # KeyError: missing required property (e.g., 'input_value', 'input_shape') # TypeError: invalid property value (e.g., None when expecting iterable) @@ -816,10 +819,13 @@ def _compute_dynamic_axes(shape: tuple | None, is_constant: bool) -> tuple[int, conditions[f"attr_{attr_name}"] = attr_value conditions[f"attr_{attr_name}_is_none"] = attr_value is None + conditions["n_outputs"] = len(pattern_match.skeleton_match_result.pattern.get_schema().outputs) + # Derive additional properties via pattern input generator if gen is not None: try: conditions = gen.derive_properties(conditions) + conditions.pop("n_outputs", None) infinite_properties = gen.get_infinite_property_names() except Exception as e: logger.debug("Could not derive properties for pattern '%s': %s", pattern_name, e) diff --git a/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py index e481b75fe..045cf678d 100644 --- a/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py @@ -818,9 +818,8 @@ def derive_properties(self, properties: dict) -> dict: split_array = np.array(item["attr_split"]) item["num_outputs"] = len(split_array) else: - raise ValueError( - "Shouldn't reach here: either split_value or attr_num_outputs must be present." - ) + # caller is get_query_conditions_for_*, "n_outputs" must be present + item["num_outputs"] = item["n_outputs"] # num_outputs may already be set from the combination return item diff --git a/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py b/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py index 0a14fc87e..1910ef982 100644 --- a/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py +++ b/src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py @@ -221,8 +221,20 @@ def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: # Normalize negative axes to positive for comparison normalized_axes = np.where(axes_array < 0, axes_array + data_dim, axes_array) - # Get dimension sizes for the sliced axes - axis_dims = np.array([data_shape[int(ax)] for ax in normalized_axes], dtype=np.int64) + + # Filter to axes whose dimensions are fixed (int, not string/symbolic) + fixed_mask = np.array( + [ + isinstance(data_shape[int(ax)], int) and data_shape[int(ax)] >= 0 + for ax in normalized_axes + ] + ) + + # Get dimension sizes only for fixed-shape axes + axis_dims = np.array( + [data_shape[int(ax)] if fixed_mask[i] else 0 for i, ax in enumerate(normalized_axes)], + dtype=np.int64, + ) # Derive starts-related properties starts_value = item.get("starts_value") @@ -231,7 +243,7 @@ def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: else: starts_array = np.array(starts_value, dtype=np.int64) - # Normalize negative starts: add axis_dim if negative + # Normalize negative starts: add axis_dim if negative (only meaningful for fixed axes) normalized_starts = np.where(starts_array < 0, starts_array + axis_dims, starts_array) # Derive ends-related properties @@ -241,16 +253,19 @@ def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]: else: ends_array = np.array(ends_value, dtype=np.int64) - # Normalize negative ends: add axis_dim if negative + # Normalize negative ends: add axis_dim if negative (only meaningful for fixed axes) normalized_ends = np.where(ends_array < 0, ends_array + axis_dims, ends_array) - # Check if starts are at the last element of each sliced axis - item["starts_equal_shape"] = bool(np.all(normalized_starts == axis_dims - 1)) - - # Check if this is a full slice (starts at 0, ends at dimension size) - item["slice_all"] = bool( - np.all(normalized_starts == 0) and np.all(normalized_ends >= axis_dims) - ) + # Check starts_equal_shape and slice_all only on fixed-shape axes + if np.any(fixed_mask): + fixed_starts = normalized_starts[fixed_mask] + fixed_ends = normalized_ends[fixed_mask] + fixed_dims = axis_dims[fixed_mask] + item["starts_equal_shape"] = bool(np.all(fixed_starts == fixed_dims - 1)) + item["slice_all"] = bool(np.all(fixed_starts == 0) and np.all(fixed_ends >= fixed_dims)) + else: + item["starts_equal_shape"] = False + item["slice_all"] = False # Derive steps-related properties steps_value = item.get("steps_value")