Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/winml/modelkit/analyze/core/runtime_checker_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,12 @@ def update_conditions_(
update_conditions_(conditions, input_name, is_variadic, is_constant, shape, None)
conditions[f"{input_name}_is_none"] = False

conditions["n_outputs"] = len(node.output)

# Try to derive properties, but catch errors for incomplete/invalid model information
try:
conditions = runtime_checker_op.derive_properties(conditions)
conditions.pop("n_outputs", None)
except (KeyError, TypeError, IndexError) as e:
# KeyError: missing required property (e.g., 'input_value', 'input_shape')
# TypeError: invalid property value (e.g., None when expecting iterable)
Expand Down Expand Up @@ -816,10 +819,13 @@ def _compute_dynamic_axes(shape: tuple | None, is_constant: bool) -> tuple[int,
conditions[f"attr_{attr_name}"] = attr_value
conditions[f"attr_{attr_name}_is_none"] = attr_value is None

conditions["n_outputs"] = len(pattern_match.skeleton_match_result.pattern.get_schema().outputs)

# Derive additional properties via pattern input generator
if gen is not None:
try:
conditions = gen.derive_properties(conditions)
conditions.pop("n_outputs", None)
infinite_properties = gen.get_infinite_property_names()
except Exception as e:
logger.debug("Could not derive properties for pattern '%s': %s", pattern_name, e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -818,9 +818,8 @@ def derive_properties(self, properties: dict) -> dict:
split_array = np.array(item["attr_split"])
item["num_outputs"] = len(split_array)
else:
raise ValueError(
"Shouldn't reach here: either split_value or attr_num_outputs must be present."
)
# caller is get_query_conditions_for_*, "n_outputs" must be present
item["num_outputs"] = item["n_outputs"]
# num_outputs may already be set from the combination
return item

Expand Down
37 changes: 26 additions & 11 deletions src/winml/modelkit/pattern/op_input_gen/slice_input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,20 @@ def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]:

# Normalize negative axes to positive for comparison
normalized_axes = np.where(axes_array < 0, axes_array + data_dim, axes_array)
# Get dimension sizes for the sliced axes
axis_dims = np.array([data_shape[int(ax)] for ax in normalized_axes], dtype=np.int64)

# Filter to axes whose dimensions are fixed (int, not string/symbolic)
fixed_mask = np.array(
[
isinstance(data_shape[int(ax)], int) and data_shape[int(ax)] >= 0
for ax in normalized_axes
]
)

# Get dimension sizes only for fixed-shape axes
axis_dims = np.array(
[data_shape[int(ax)] if fixed_mask[i] else 0 for i, ax in enumerate(normalized_axes)],
dtype=np.int64,
)

# Derive starts-related properties
starts_value = item.get("starts_value")
Expand All @@ -231,7 +243,7 @@ def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]:
else:
starts_array = np.array(starts_value, dtype=np.int64)

# Normalize negative starts: add axis_dim if negative
# Normalize negative starts: add axis_dim if negative (only meaningful for fixed axes)
normalized_starts = np.where(starts_array < 0, starts_array + axis_dims, starts_array)

# Derive ends-related properties
Expand All @@ -241,16 +253,19 @@ def derive_properties(self, properties: dict[str, Any]) -> dict[str, Any]:
else:
ends_array = np.array(ends_value, dtype=np.int64)

# Normalize negative ends: add axis_dim if negative
# Normalize negative ends: add axis_dim if negative (only meaningful for fixed axes)
normalized_ends = np.where(ends_array < 0, ends_array + axis_dims, ends_array)

# Check if starts are at the last element of each sliced axis
item["starts_equal_shape"] = bool(np.all(normalized_starts == axis_dims - 1))

# Check if this is a full slice (starts at 0, ends at dimension size)
item["slice_all"] = bool(
np.all(normalized_starts == 0) and np.all(normalized_ends >= axis_dims)
)
# Check starts_equal_shape and slice_all only on fixed-shape axes
if np.any(fixed_mask):
fixed_starts = normalized_starts[fixed_mask]
fixed_ends = normalized_ends[fixed_mask]
fixed_dims = axis_dims[fixed_mask]
item["starts_equal_shape"] = bool(np.all(fixed_starts == fixed_dims - 1))
item["slice_all"] = bool(np.all(fixed_starts == 0) and np.all(fixed_ends >= fixed_dims))
else:
item["starts_equal_shape"] = False
item["slice_all"] = False

# Derive steps-related properties
steps_value = item.get("steps_value")
Expand Down
Loading