diff --git a/src/op_system/__init__.py b/src/op_system/__init__.py index a8ff0b8..e915b6e 100644 --- a/src/op_system/__init__.py +++ b/src/op_system/__init__.py @@ -85,6 +85,16 @@ def compile_spec( # noqa: RUF067 Returns: CompiledRhs: Runnable RHS callable container. + + Examples: + >>> import numpy as np + >>> compiled = compile_spec({ + ... "kind": "expr", + ... "state": ["x"], + ... "equations": {"x": "-x"}, + ... }) + >>> compiled.eval_fn(0.0, np.array([1.0])) + array([-1.]) """ if xp is not None or backend != DEFAULT_ARRAY_BACKEND: warnings.warn( diff --git a/src/op_system/_axes.py b/src/op_system/_axes.py index 67f3775..498b407 100644 --- a/src/op_system/_axes.py +++ b/src/op_system/_axes.py @@ -53,6 +53,20 @@ def _normalize_bracket_key(key: str) -> str: def _normalize_axis_name(ax_map: Mapping[str, Any], *, idx: int, seen: set[str]) -> str: + """Validate and return the ``name`` field of one axis definition. + + Args: + ax_map: Raw axis mapping. + idx: Position in the surrounding ``axes`` list (for diagnostics). + seen: Mutable set of already-registered axis names; updated in place. + + Returns: + The validated, stripped axis name. + + Raises: + InvalidRhsSpecError: If ``name`` is missing, not a string, empty, or + duplicates an earlier axis. + """ name_val = ax_map.get("name") if not isinstance(name_val, str) or not name_val.strip(): raise InvalidRhsSpecError(detail=f"axes[{idx}].name must be a non-empty string") @@ -64,6 +78,18 @@ def _normalize_axis_name(ax_map: Mapping[str, Any], *, idx: int, seen: set[str]) def _normalize_axis_type(ax_map: Mapping[str, Any], *, idx: int) -> str: + """Validate and return the ``type`` field of one axis (default categorical). + + Args: + ax_map: Raw axis mapping. + idx: Position in the surrounding ``axes`` list (for diagnostics). + + Returns: + One of ``"categorical"``, ``"ordinal"``, or ``"continuous"``. + + Raises: + InvalidRhsSpecError: If ``type`` is set to anything else. + """ ax_type = str(ax_map.get("type", "categorical")).strip().lower() if ax_type not in {"categorical", "ordinal", "continuous"}: raise InvalidRhsSpecError( @@ -75,6 +101,19 @@ def _normalize_axis_type(ax_map: Mapping[str, Any], *, idx: int) -> str: def _normalize_axis_units(ax_map: Mapping[str, Any], *, idx: int) -> str | None: + """Validate and return the optional ``units`` field of one axis. + + Args: + ax_map: Raw axis mapping. + idx: Position in the surrounding ``axes`` list (for diagnostics). + + Returns: + Stripped units string or ``None`` when absent. + + Raises: + InvalidRhsSpecError: If ``units`` is provided but is not a non-empty + string. + """ units_obj = ax_map.get("units") if units_obj is None: return None @@ -91,6 +130,23 @@ def _normalize_axis_coords( idx: int, ax_type: str, ) -> tuple[list[Any], int]: + """Validate explicit ``coords`` for one axis. + + Categorical and ordinal axes must have non-empty unique string coords; + continuous axes coerce values to numbers and require monotonic + non-decreasing order. + + Args: + coords_obj: Raw value of ``coords``. + idx: Position in the surrounding ``axes`` list (for diagnostics). + ax_type: Already-validated axis type. + + Returns: + ``(coords, size)`` pair. + + Raises: + InvalidRhsSpecError: If validation fails. + """ if not isinstance(coords_obj, (list, tuple)) or not coords_obj: raise InvalidRhsSpecError(detail=f"axes[{idx}].coords must be a non-empty list") coords = list(coords_obj) @@ -164,6 +220,20 @@ def _compute_axis_deltas(coords: list[float], *, idx: int) -> list[float]: def _generate_continuous_coords( *, domain: object, size_obj: object, spacing: str, idx: int ) -> tuple[list[float], int]: + """Generate ``coords`` for a continuous axis from ``domain``/``size``/``spacing``. + + Args: + domain: Raw ``domain`` mapping with ``lb``/``ub``. + size_obj: Raw ``size`` value (must be an integer >= 2). + spacing: One of ``"linear"``, ``"log"``, ``"geom"``. + idx: Position in the surrounding ``axes`` list (for diagnostics). + + Returns: + ``(coords, size)`` pair where ``coords`` has length ``size``. + + Raises: + InvalidRhsSpecError: On invalid bounds, size, or spacing. + """ domain_map = ( _ensure_mapping(domain, name=f"axes[{idx}].domain") if domain is not None @@ -221,6 +291,21 @@ def _generate_continuous_coords( def _normalize_single_axis( ax_map: Mapping[str, Any], *, idx: int, seen: set[str] ) -> dict[str, Any]: + """Normalize one axis mapping into the canonical record. + + Args: + ax_map: Raw axis mapping. + idx: Position in the surrounding ``axes`` list (for diagnostics). + seen: Mutable set of already-registered axis names; updated in place. + + Returns: + Canonical axis dict (``name``, ``type``, ``coords``, ``size``, + and optionally ``deltas``, ``domain``, ``spacing``, ``units``). + + Raises: + InvalidRhsSpecError: If the axis is categorical or ordinal but + has no ``coords`` field. + """ name = _normalize_axis_name(ax_map, idx=idx, seen=seen) ax_type = _normalize_axis_type(ax_map, idx=idx) spacing = str(ax_map.get("spacing", "linear")).strip().lower() diff --git a/src/op_system/_errors.py b/src/op_system/_errors.py index 8e4f15b..01808fe 100644 --- a/src/op_system/_errors.py +++ b/src/op_system/_errors.py @@ -29,6 +29,13 @@ def __init__( missing: list[str] | None = None, detail: str | None = None, ) -> None: + """Initialize the error. + + Args: + missing: Optional list of missing required field names. + detail: Optional human-readable detail describing the + violation. + """ self.missing = list(missing) if missing else None self.detail = detail parts: list[str] = [_INVALID_RHS_SPEC_PREFIX] @@ -47,6 +54,12 @@ class InvalidExpressionError(ValueError): """ def __init__(self, *, detail: str) -> None: + """Initialize the error. + + Args: + detail: Human-readable detail describing the parse or + validation failure. + """ self.detail = detail super().__init__(f"{_INVALID_EXPRESSION_PREFIX} Detail: {detail}") @@ -60,6 +73,12 @@ class UnsupportedFeatureError(NotImplementedError): """ def __init__(self, *, feature: str, detail: str | None = None) -> None: + """Initialize the error. + + Args: + feature: Identifier for the unsupported feature. + detail: Optional additional detail. + """ self.feature = feature self.detail = detail msg = f"{_UNSUPPORTED_FEATURE_PREFIX} Feature {feature!r} is not supported." diff --git a/src/op_system/_ir.py b/src/op_system/_ir.py index 89a9add..538a654 100644 --- a/src/op_system/_ir.py +++ b/src/op_system/_ir.py @@ -139,16 +139,41 @@ class Reduce: def _invalid(*, detail: str) -> NoReturn: + """Raise :class:`InvalidExpressionError` with ``detail``. + + Args: + detail: Human-readable detail describing the violation. + + Raises: + InvalidExpressionError: Always. + """ raise InvalidExpressionError(detail=detail) def _kwarg_name(arg: Expr) -> str: + """Return the string key carried by a ``kwarg`` marker's first arg. + + Args: + arg: First argument of an ``Apply(op="kwarg", ...)`` node. + + Returns: + The string key. + """ if isinstance(arg, Literal) and isinstance(arg.value, str): return arg.value _invalid(detail="kwarg marker must use a string key") def _kwarg_value_as_binding(value: Expr, *, key: str) -> str: + """Render an ``apply_along`` kwarg value as an axis-binding identifier. + + Args: + value: IR value attached to the kwarg. + key: Kwarg key (used only for error reporting). + + Returns: + Bare identifier or stringified literal. + """ if isinstance(value, Sym): return value.name if isinstance(value, Literal): @@ -190,6 +215,22 @@ def _extract_filter_from_value(value: Expr) -> tuple[str, tuple[str, ...]] | Non def _lower_single_helper(node: Apply) -> Expr: # noqa: C901, PLR0912 + """Lower one helper :class:`Apply` node into a structured :class:`Reduce`. + + Helper ops in :data:`_HELPER_REDUCE_OPS` (``sum``, ``integrate``, + ``apply_along``) carry positional bodies and ``kwarg`` markers. This + routine partitions the args, parses ``axis=binding`` and ``axis=var in + [...]`` filter forms, and returns a :class:`Reduce` node with + ``bindings``, ``filters`` and optional ``kernel`` populated. + + Non-helper :class:`Apply` nodes are returned unchanged. + + Args: + node: An :class:`Apply` node to lower. + + Returns: + Lowered :class:`Reduce` node, or ``node`` if it is not a helper call. + """ if node.op not in _HELPER_REDUCE_OPS: return node @@ -252,6 +293,11 @@ def _lower_single_helper(node: Apply) -> Expr: # noqa: C901, PLR0912 def _validate_reduce(reduce_node: Reduce) -> None: + """Validate the structural invariants of a :class:`Reduce` node. + + Args: + reduce_node: The candidate node to validate. + """ if not reduce_node.bindings: _invalid(detail=(f"{reduce_node.kind} requires at least one axis=var binding")) @@ -283,6 +329,14 @@ def lower_helper_calls(expr: Expr) -> Expr: def _call_name(func: ast.AST) -> str: + """Render a Python AST call target as a dotted name. + + Args: + func: ``func`` slot of an :class:`ast.Call` node. + + Returns: + Dotted identifier (e.g. ``"np.sum"`` or ``"sum"``). + """ if isinstance(func, ast.Name): return func.id if isinstance(func, ast.Attribute): @@ -298,6 +352,17 @@ def _call_name(func: ast.AST) -> str: def _axis_index_from_expr(node: ast.expr) -> AxisIndex: + """Build an :class:`AxisIndex` from a non-slice subscript element AST. + + Bare names map to ``AxisIndex(axis=name)``; constants map to a + coord-only ``AxisIndex(axis="", coord=str(value))``. + + Args: + node: Subscript element AST node. + + Returns: + Parsed :class:`AxisIndex`. + """ if isinstance(node, ast.Name): return AxisIndex(axis=node.id) if isinstance(node, ast.Constant) and isinstance( @@ -308,6 +373,14 @@ def _axis_index_from_expr(node: ast.expr) -> AxisIndex: def _bound_axis_index_from_slice(node: ast.Slice) -> AxisIndex: + """Build an :class:`AxisIndex` from a ``axis:binding`` slice node. + + Args: + node: ``ast.Slice`` of the form ``axis:binding``. + + Returns: + :class:`AxisIndex` carrying both ``axis`` and ``coord``. + """ if node.step is not None: _invalid(detail="bound-axis subscript axis:binding must not have a step") if node.lower is None or node.upper is None: @@ -318,12 +391,30 @@ def _bound_axis_index_from_slice(node: ast.Slice) -> AxisIndex: def _parse_subscript_index_element(node: ast.expr) -> AxisIndex: + """Parse a single subscript element (slice or expression) to :class:`AxisIndex`. + + Args: + node: Subscript element AST node. + + Returns: + Parsed :class:`AxisIndex`. + """ if isinstance(node, ast.Slice): return _bound_axis_index_from_slice(node) return _axis_index_from_expr(node) def _parse_subscript_indices(slc: ast.expr) -> tuple[AxisIndex, ...]: + """Parse a subscript slot into a tuple of :class:`AxisIndex` entries. + + Handles both single-element (``a[i]``) and tuple (``a[i, j]``) forms. + + Args: + slc: ``slice`` slot of an ``ast.Subscript`` node. + + Returns: + Tuple of parsed :class:`AxisIndex` entries. + """ if isinstance(slc, ast.Tuple): return tuple(_parse_subscript_index_element(elt) for elt in slc.elts) return (_parse_subscript_index_element(slc),) @@ -458,6 +549,14 @@ def parse_expr_to_ir(expr: str, *, lower_helpers: bool = False) -> Expr: def _literal_coord_value(coord: str) -> int | float | str: + """Cast a coord token to ``int``/``float`` if numeric, else return the string. + + Args: + coord: Coord token (string form). + + Returns: + ``int``, ``float`` or ``str`` value depending on parseability. + """ try: return int(coord) except ValueError: @@ -468,12 +567,31 @@ def _literal_coord_value(coord: str) -> int | float | str: def _axis_index_to_ast(idx: AxisIndex) -> ast.expr: + """Render an :class:`AxisIndex` to a Python AST expression node. + + Coord-bearing indices become typed ``ast.Constant`` nodes; bare-axis + indices become ``ast.Name`` nodes. + + Args: + idx: Index to render. + + Returns: + AST expression node. + """ if idx.coord is not None: return ast.Constant(value=_literal_coord_value(idx.coord)) return ast.Name(id=idx.axis, ctx=ast.Load()) def _call_func_ast(name: str) -> ast.expr: + """Build the ``func`` slot of an ``ast.Call`` from a dotted name string. + + Args: + name: Dotted identifier (e.g. ``"np.sum"``). + + Returns: + AST expression node suitable for ``ast.Call.func``. + """ parts = name.split(".") node: ast.expr = ast.Name(id=parts[0], ctx=ast.Load()) for part in parts[1:]: @@ -482,6 +600,14 @@ def _call_func_ast(name: str) -> ast.expr: def _binary_ast(op: str) -> ast.operator: + """Translate an IR binary op token to its ``ast.operator`` instance. + + Args: + op: IR op string (``"+"``, ``"*"``, ``"pow"``, ...). + + Returns: + Concrete :class:`ast.operator` subclass instance. + """ if op == "+": return ast.Add() if op == "-": @@ -498,6 +624,14 @@ def _binary_ast(op: str) -> ast.operator: def _compare_ast(op: str) -> ast.cmpop: + """Translate an IR comparison op token to its ``ast.cmpop`` instance. + + Args: + op: Comparison op string (``"=="``, ``"<"``, ...). + + Returns: + Concrete :class:`ast.cmpop` subclass instance. + """ if op == "==": return ast.Eq() if op == "!=": @@ -625,6 +759,15 @@ def _render_coord(coord: object) -> str: # noqa: PLR0911 def _unparse_axis_index(idx: AxisIndex) -> str: + """Render an :class:`AxisIndex` back to its source-string form. + + Args: + idx: Index to render. + + Returns: + ``"axis:coord"``, ``"coord"`` or ``"axis"`` depending on which + fields are populated. + """ if idx.coord is not None: rendered = _render_coord(idx.coord) if idx.axis and idx.axis != idx.coord: @@ -634,6 +777,16 @@ def _unparse_axis_index(idx: AxisIndex) -> str: def _unparse_call_args(args: tuple[Expr, ...]) -> str: + """Render the arg list of an :class:`Apply` (call) back to source form. + + Handles ``kwarg`` marker nodes by emitting ``key=value`` syntax. + + Args: + args: Apply argument tuple. + + Returns: + Comma-separated argument source string. + """ parts: list[str] = [] for arg in args: if isinstance(arg, Apply) and arg.op == "kwarg": @@ -691,6 +844,14 @@ def unparse_ir(expr: Expr) -> str: def _expr_precedence(expr: Expr) -> int: + """Return the unparser precedence level for ``expr``. + + Args: + expr: IR expression. + + Returns: + Integer precedence used by the unparser to decide on parentheses. + """ if isinstance(expr, Apply): if expr.op in {"neg", "pos"}: return _PREC_UNARY @@ -705,6 +866,15 @@ def _expr_precedence(expr: Expr) -> int: def _wrap(text: str, *, need: bool) -> str: + """Optionally parenthesize ``text``. + + Args: + text: Already-rendered source string. + need: Whether parentheses are required. + + Returns: + Either ``"(text)"`` or ``text``. + """ return f"({text})" if need else text @@ -716,6 +886,18 @@ def _unparse_binary( parent_prec: int, is_right: bool, ) -> str: + """Render a binary-op :class:`Apply` to source, folding multi-arg chains. + + Args: + expr: Apply node carrying two-or-more operands. + op: Operator token to render between operands. + prec: Precedence level of ``op``. + parent_prec: Precedence level of the enclosing context. + is_right: Whether this expression is the right operand of its parent. + + Returns: + Source string for the binary expression, parenthesized if needed. + """ sep = f" {op} " # Fold args left-associatively so multi-arg flatten still renders correctly. args = expr.args @@ -736,6 +918,16 @@ def _unparse_ir( # noqa: C901, PLR0911 parent_prec: int, is_right: bool, ) -> str: + """Recursive worker for :func:`unparse_ir`. + + Args: + expr: IR expression to render. + parent_prec: Precedence level of the enclosing context. + is_right: Whether ``expr`` sits on the right of its parent operator. + + Returns: + Source string equivalent to ``expr``. + """ if isinstance(expr, Literal): return repr(expr.value) @@ -998,6 +1190,18 @@ def substitute( def _map_children(expr: Expr, fn: Callable[[Expr], Expr]) -> Expr: + """Apply ``fn`` to each immediate child and return a rebuilt node. + + Returns the original ``expr`` if no child changed (preserving identity + so callers can use ``is`` checks for short-circuiting). + + Args: + expr: IR expression. + fn: Per-child rewrite function. + + Returns: + Possibly-rewritten IR expression. + """ if isinstance(expr, Apply): new_args = tuple(fn(arg) for arg in expr.args) if new_args == expr.args: @@ -1018,10 +1222,29 @@ def _map_children(expr: Expr, fn: Callable[[Expr], Expr]) -> Expr: def _is_cse_candidate(expr: Expr) -> bool: + """Return ``True`` when ``expr`` is eligible for CSE extraction. + + Args: + expr: IR expression. + + Returns: + ``True`` for :class:`Apply` and :class:`Reduce` nodes. + """ return isinstance(expr, (Apply, Reduce)) def _expr_cost(expr: Expr) -> int: + """Return a structural cost estimate for ``expr``. + + The cost is simply ``1`` per node so that subtrees can be ranked by + size when deciding which repeated subtrees are worth extracting. + + Args: + expr: IR expression. + + Returns: + Integer node count. + """ if isinstance(expr, Apply): return 1 + sum(_expr_cost(arg) for arg in expr.args) if isinstance(expr, Reduce): @@ -1030,6 +1253,12 @@ def _expr_cost(expr: Expr) -> int: def _postorder(expr: Expr, out: dict[Expr, int]) -> None: + """Post-order walk that tallies subtree occurrences into ``out``. + + Args: + expr: IR expression to walk. + out: Mutable counter mapping; incremented in-place per node. + """ if isinstance(expr, Apply): for arg in expr.args: _postorder(arg, out) @@ -1115,6 +1344,18 @@ def axis_kinds(expr: Expr, *, axis_names: frozenset[str]) -> tuple[AxisKind, ... def _resolve_index(idx: AxisIndex, *, axis_names: frozenset[str]) -> AxisIndex: + """Return ``idx`` with its ``kind`` field populated. + + Reuses ``idx`` when its ``kind`` already matches the classification + against ``axis_names``, so resolved trees are idempotent. + + Args: + idx: Subscript index. + axis_names: Set of registered axis identifier strings. + + Returns: + :class:`AxisIndex` with ``kind`` populated. + """ kind = classify_axis_index(idx, axis_names=axis_names) if idx.kind is kind: return idx diff --git a/src/op_system/_ir_lower.py b/src/op_system/_ir_lower.py index c539da4..20726fd 100644 --- a/src/op_system/_ir_lower.py +++ b/src/op_system/_ir_lower.py @@ -90,6 +90,14 @@ class UnsupportedIRLoweringError(NotImplementedError): def _name(ident: str) -> ast.Name: + """Build an ``ast.Name`` load node for ``ident``. + + Args: + ident: Identifier text. + + Returns: + ``ast.Name`` node with ``ctx=Load()``. + """ return ast.Name(id=ident, ctx=ast.Load()) @@ -98,47 +106,32 @@ def _name(ident: str) -> ast.Name: # --------------------------------------------------------------------------- -def lower_subscript_to_buffer( # noqa: C901 +def _validate_wildcard_subscript_axes( sub: Subscript, *, src_axes: tuple[str, ...], target_axes: tuple[str, ...], axis_names: frozenset[str], - axis_alias: Mapping[str, str] | None = None, -) -> ast.expr: - """Lower a wildcard IR :class:`Subscript` to a buffer-access AST. - - The returned AST evaluates to an array that broadcasts cleanly against - a tensor of shape implied by ``target_axes`` — i.e. it carries a - singleton (``None``) dimension for every axis in ``target_axes`` not - present in the subscript and is transposed so kept axes appear in - ``target_axes`` order. + axis_alias: Mapping[str, str], +) -> tuple[list[str], list[str]]: + """Validate a wildcard subscript and return its (synthetic, real) axes. Args: - sub: IR subscript to lower. Must have one :class:`AxisKind.FREE` - index per element, and its axis set must equal ``src_axes`` as - a set (any permutation accepted). - src_axes: Declared axis order of the source buffer - (the ordering used to flatten ``_buf``). - target_axes: Axis order of the cell layout the result must - broadcast against. - axis_names: Registered axis identifiers; used to classify indices. - axis_alias: Optional mapping from synthetic axis labels (e.g. - ``age#ap`` emitted by :func:`_lower_reduce` for same-axis-twice - bindings) back to their real axis name. Used to match - subscript labels against ``src_axes`` while preserving the - synthetic label for ``target_axes`` alignment. + sub: IR subscript whose indices must all be FREE. + src_axes: Declared source-buffer axis order. + target_axes: Cell-layout axis order. + axis_names: Registered axis identifiers. + axis_alias: Mapping from synthetic to real axis labels. Returns: - An ``ast.expr`` accessing ``_buf`` with the necessary - transpose / size-1 insertions to align with ``target_axes``. + Pair ``(sub_axes, real_sub_axes)`` of synthetic and real axis + labels in the order they appear in ``sub.indices``. Raises: - UnsupportedIRLoweringError: If any index is non-FREE, the index axis - set doesn't match ``src_axes``, or ``src_axes`` is not a - subset of ``target_axes``. + UnsupportedIRLoweringError: If index/axis count mismatch, any + index is not FREE, the index axes don't match ``src_axes`` as + a set, or ``src_axes`` is not a subset of ``target_axes``. """ - alias = axis_alias or {} if len(sub.indices) != len(src_axes): msg = ( f"subscript {sub.name!r} has {len(sub.indices)} indices but " @@ -158,7 +151,7 @@ def lower_subscript_to_buffer( # noqa: C901 ) raise UnsupportedIRLoweringError(msg) sub_axes.append(idx.axis) - real_sub_axes.append(alias.get(idx.axis, idx.axis)) + real_sub_axes.append(axis_alias.get(idx.axis, idx.axis)) if set(real_sub_axes) != set(src_axes): msg = ( @@ -174,6 +167,53 @@ def lower_subscript_to_buffer( # noqa: C901 ) raise UnsupportedIRLoweringError(msg) + return sub_axes, real_sub_axes + + +def lower_subscript_to_buffer( + sub: Subscript, + *, + src_axes: tuple[str, ...], + target_axes: tuple[str, ...], + axis_names: frozenset[str], + axis_alias: Mapping[str, str] | None = None, +) -> ast.expr: + """Lower a wildcard IR :class:`Subscript` to a buffer-access AST. + + The returned AST evaluates to an array that broadcasts cleanly against + a tensor of shape implied by ``target_axes`` — i.e. it carries a + singleton (``None``) dimension for every axis in ``target_axes`` not + present in the subscript and is transposed so kept axes appear in + ``target_axes`` order. + + Args: + sub: IR subscript to lower. Must have one :class:`AxisKind.FREE` + index per element, and its axis set must equal ``src_axes`` as + a set (any permutation accepted). + src_axes: Declared axis order of the source buffer + (the ordering used to flatten ``_buf``). + target_axes: Axis order of the cell layout the result must + broadcast against. + axis_names: Registered axis identifiers; used to classify indices. + axis_alias: Optional mapping from synthetic axis labels (e.g. + ``age#ap`` emitted by :func:`_lower_reduce` for same-axis-twice + bindings) back to their real axis name. Used to match + subscript labels against ``src_axes`` while preserving the + synthetic label for ``target_axes`` alignment. + + Returns: + An ``ast.expr`` accessing ``_buf`` with the necessary + transpose / size-1 insertions to align with ``target_axes``. + """ + alias = axis_alias or {} + sub_axes, real_sub_axes = _validate_wildcard_subscript_axes( + sub, + src_axes=src_axes, + target_axes=target_axes, + axis_names=axis_names, + axis_alias=alias, + ) + buf: ast.expr = _name(f"{sub.name}_buf") # Reorder buffer (stored in src_axes order) to sub_axes order if the @@ -214,6 +254,15 @@ def lower_subscript_to_buffer( # noqa: C901 def _transpose(node: ast.expr, perm: tuple[int, ...]) -> ast.expr: + """Wrap ``node`` in a ``np.transpose(node, perm)`` call AST. + + Args: + node: AST expression that evaluates to an array. + perm: Permutation tuple to pass as the second argument. + + Returns: + ``ast.Call`` node representing ``np.transpose(node, perm)``. + """ return ast.Call( func=ast.Attribute(value=_name("np"), attr="transpose", ctx=ast.Load()), args=[ @@ -579,7 +628,49 @@ def _resolve_ordinal_range_filter( return tuple(range(lo_idx, hi_idx + 1)) -def _resolve_continuous_range_filter( # noqa: C901 +def _trapezoidal_weights_or_raise( + sub_floats: list[float], *, axis: str +) -> tuple[float, ...]: + """Compute trapezoidal weights for a strictly increasing coord list. + + Args: + sub_floats: Numeric coords selected by a continuous-axis filter, + in declared order. + axis: Axis name, used in diagnostics. + + Returns: + Per-coord trapezoidal weights as a tuple. + + Raises: + UnsupportedIRLoweringError: If fewer than two coords are provided + or the coords are not strictly increasing. + """ + if len(sub_floats) < 2: + msg = ( + f"Reduce filter for continuous axis {axis!r} sub-interval needs " + "at least 2 coords for trapezoidal integration" + ) + raise UnsupportedIRLoweringError(msg) + sub_weights: list[float] = [] + for i in range(len(sub_floats)): + if i == 0: + width = (sub_floats[1] - sub_floats[0]) / 2.0 + elif i == len(sub_floats) - 1: + width = (sub_floats[-1] - sub_floats[-2]) / 2.0 + else: + width = (sub_floats[i + 1] - sub_floats[i - 1]) / 2.0 + if width <= 0.0: + msg = ( + f"Reduce filter for continuous axis {axis!r} sub-interval " + "coords must be strictly increasing for trapezoidal " + "integration" + ) + raise UnsupportedIRLoweringError(msg) + sub_weights.append(width) + return tuple(sub_weights) + + +def _resolve_continuous_range_filter( axis: str, declared: tuple[str, ...], filt: tuple[str, ...], @@ -643,29 +734,7 @@ def _resolve_continuous_range_filter( # noqa: C901 if not recompute_trapezoidal: return indices, None sub_floats = [coord_floats[i] for i in indices] - if len(sub_floats) < 2: - msg = ( - f"Reduce filter for continuous axis {axis!r} sub-interval needs " - "at least 2 coords for trapezoidal integration" - ) - raise UnsupportedIRLoweringError(msg) - sub_weights: list[float] = [] - for i in range(len(sub_floats)): - if i == 0: - width = (sub_floats[1] - sub_floats[0]) / 2.0 - elif i == len(sub_floats) - 1: - width = (sub_floats[-1] - sub_floats[-2]) / 2.0 - else: - width = (sub_floats[i + 1] - sub_floats[i - 1]) / 2.0 - if width <= 0.0: - msg = ( - f"Reduce filter for continuous axis {axis!r} sub-interval " - "coords must be strictly increasing for trapezoidal " - "integration" - ) - raise UnsupportedIRLoweringError(msg) - sub_weights.append(width) - return indices, tuple(sub_weights) + return indices, _trapezoidal_weights_or_raise(sub_floats, axis=axis) def _lower_reduce( # noqa: C901, PLR0912, PLR0913, PLR0914, PLR0915 @@ -1062,7 +1131,43 @@ def _lower_apply( # noqa: PLR0913 shaped_param_axes: Mapping[str, tuple[str, ...]] | None = None, axis_alias: Mapping[str, str] | None = None, ) -> ast.expr: + """Lower an :class:`Apply` IR node to its NumPy-AST equivalent. + + Recursive worker for :func:`lower_to_vector_ast`. Handles unary, + binary, comparison, ternary and helper-call shapes by delegating to + :func:`lower_to_vector_ast` for each child. + + The keyword arguments mirror :func:`lower_to_vector_ast` exactly; see + that function for full descriptions. + + Args: + expr: Apply node to lower. + target_axes: Cell-layout axis order the result must broadcast in. + buffer_axes: Map from state-template name to its declared axes. + axis_names: Set of registered axis identifiers. + reducible_axes: Set of axis names that ``sum``/``integrate`` + helpers may reduce over. + axis_weights: Per-axis trapezoidal-rule weights for ``integrate``. + axis_coords: Per-axis coord lists used by filter helpers. + axis_types: Per-axis type tags (``"categorical"``, ``"ordinal"``, + ``"continuous"``). + shaped_param_axes: Map from shaped-param name to its declared axes. + axis_alias: Optional alias from synthetic to real axis labels. + + Returns: + Lowered AST expression. + + Raises: + UnsupportedIRLoweringError: If ``expr`` has a shape not handled by + this lowerer. + """ + def lower(child: Expr) -> ast.expr: + """Lower a child node with the same surrounding configuration. + + Returns: + Lowered AST expression for ``child``. + """ return lower_to_vector_ast( child, target_axes=target_axes, diff --git a/src/op_system/_ir_templates.py b/src/op_system/_ir_templates.py index b42ca3e..be6ef1e 100644 --- a/src/op_system/_ir_templates.py +++ b/src/op_system/_ir_templates.py @@ -91,6 +91,25 @@ def _expand_subscript( shaped_params: Mapping[str, tuple[str, ...]], axis_lookup: Mapping[str, Sequence[str]] | None, ) -> Expr: + """Expand one :class:`Subscript` against ``assignment``. + + Shaped-parameter subscripts (those whose base name is in + ``shaped_params``) are rewritten to integer-coord :class:`AxisIndex` + entries; non-shaped subscripts collapse to a templated scalar + :class:`Sym` whose name is rendered from the base and assignment. + + Args: + sub: Subscript node to expand. + assignment: Map from axis placeholder name to concrete coord. + shaped_params: Shaped-parameter axis registry. + axis_lookup: Map from axis name to ordered coord list, required to + resolve shaped indices. + + Returns: + Either an updated :class:`Subscript` (shaped path), a scalar + :class:`Sym` (template path), or ``sub`` if expansion is not + applicable (e.g. literal-coord indices). + """ axes = _expandable_axes(sub.indices) if axes is None or not axes: return sub @@ -205,6 +224,17 @@ def _collect_subscript_axes( shaped: Mapping[str, tuple[str, ...]], out: dict[str, None], ) -> None: + """Record bare-axis names from one :class:`Subscript` into ``out``. + + Skips shaped-parameter subscripts (their indices resolve to fixed + integer positions) and literal-coord indices. + + Args: + sub: Subscript node to inspect. + shaped: Shaped-parameter registry. + out: Insertion-ordered set (modeled as ``dict[str, None]``); + updated in place. + """ if sub.name in shaped: return for idx in sub.indices: @@ -220,6 +250,17 @@ def _walk_for_axes( shaped: Mapping[str, tuple[str, ...]], seen: dict[str, None], ) -> None: + """Walk ``node`` and accumulate free axis names into ``seen``. + + :class:`Reduce` bindings shadow outer names: any axis name bound by a + nested :class:`Reduce` is removed from ``seen`` after walking the + body, so the result reflects the *free* axes only. + + Args: + node: IR expression to walk. + shaped: Shaped-parameter registry. + seen: Insertion-ordered set; updated in place. + """ if isinstance(node, (Literal, Sym)): return if isinstance(node, Subscript): @@ -323,6 +364,11 @@ def _detect_alias_cycle(aliases: Mapping[str, Expr]) -> list[str] | None: stack: list[str] = [] def visit(node: str) -> list[str] | None: + """DFS visitor; returns the cycle path if a back edge is found. + + Returns: + List of alias names forming a cycle, or ``None`` if none. + """ color[node] = _CYCLE_GREY stack.append(node) for nxt in graph[node]: diff --git a/src/op_system/_normalize.py b/src/op_system/_normalize.py index f1ae39b..e834d57 100644 --- a/src/op_system/_normalize.py +++ b/src/op_system/_normalize.py @@ -263,6 +263,18 @@ def normalize_rhs(spec: Mapping[str, Any] | None) -> NormalizedRhs: Raises: InvalidRhsSpecError: If validation fails. UnsupportedFeatureError: If validation fails. + + Examples: + >>> spec = { + ... "kind": "expr", + ... "state": ["x"], + ... "equations": {"x": "2.0 * x"}, + ... } + >>> rhs = normalize_rhs(spec) + >>> isinstance(rhs, ExprRhs) + True + >>> rhs.state_names + ('x',) """ if spec is None: raise InvalidRhsSpecError(detail="rhs specification is required") @@ -292,6 +304,17 @@ def normalize_expr_rhs(spec: Mapping[str, Any]) -> ExprRhs: # noqa: C901, PLR09 Raises: InvalidRhsSpecError: If validation fails. + + Examples: + >>> rhs = normalize_expr_rhs({ + ... "kind": "expr", + ... "state": ["x", "y"], + ... "equations": {"x": "-y", "y": "x"}, + ... }) + >>> rhs.state_names + ('x', 'y') + >>> len(rhs.equations) + 2 """ state_raw = _ensure_str_list(spec.get("state"), name="state") if len(state_raw) != len(set(state_raw)): @@ -624,6 +647,15 @@ def _build_transition_equations_ir( # noqa: C901, PLR0912, PLR0913, PLR0914, PL sys.setrecursionlimit(old_limit) def _sum_terms(terms: list[Expr]) -> Expr: + """Combine ``terms`` into a single :class:`Apply` (or a literal 0). + + Returns ``Literal(0.0)`` for an empty list, the lone term unchanged + for a single-element list, and a flat ``Apply('+', ...)`` for two + or more terms. + + Returns: + Combined IR expression. + """ if not terms: return Literal(value=0.0) if len(terms) == 1: diff --git a/src/op_system/_normalize_ir.py b/src/op_system/_normalize_ir.py index 899603b..63baa1d 100644 --- a/src/op_system/_normalize_ir.py +++ b/src/op_system/_normalize_ir.py @@ -11,7 +11,7 @@ import contextlib import sys from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple if TYPE_CHECKING: import re @@ -228,6 +228,11 @@ def _strip_time_axis_in_expr( return expr def _rewrite(match: re.Match[str]) -> str: + """Rewrite one ``base[time, ...]`` match by dropping the time axis. + + Returns: + Replacement text, or the original match if no rewrite applies. + """ base = match.group(1) full = tv_full_axes.get(base) if full is None: @@ -535,7 +540,105 @@ def _build_equations_ir( sys.setrecursionlimit(old_limit) -def _build_aliases_ir_from_raw( # noqa: C901 +def _parse_alias_body_to_ir(raw_name: str, expr_str: object) -> Expr: + """Parse a single alias body string into IR with diagnostics. + + Args: + raw_name: Original alias key, used in error messages. + expr_str: Raw alias body. Must be a non-empty string. + + Returns: + Parsed alias IR (helpers already lowered). + + Raises: + InvalidRhsSpecError: If ``expr_str`` is not a non-empty string or + cannot be parsed as a valid expression. + """ + expr_s = expr_str.strip() if isinstance(expr_str, str) else "" + if not expr_s: + raise InvalidRhsSpecError( + detail=f"aliases[{raw_name!r}] must be a non-empty string" + ) + try: + return parse_expr_to_ir(expr_s, lower_helpers=True) + except Exception as exc: + raise InvalidRhsSpecError( + detail=f"aliases[{raw_name!r}] has invalid expression: {exc}" + ) from exc + + +class _AliasExpansionContext(NamedTuple): + """Bundled context for per-alias IR expansion. + + Attributes: + alias_template_map: Map from canonical alias name to expansion + assignments. + axes: Resolved axes list passed through to expansion helpers. + shaped: Shaped-parameter axis declarations. + ax_lookup: Axis-coord lookup dictionary. + """ + + alias_template_map: Mapping[str, list[tuple[str, dict[str, str]]]] + axes: list[dict[str, Any]] + shaped: Mapping[str, tuple[str, ...]] + ax_lookup: Mapping[str, list[str]] + + +def _populate_alias_parsed_maps( + raw_name: str, + canonical_name: str, + ir_raw: Expr, + ctx: _AliasExpansionContext, + parsed_maps: tuple[dict[str, Expr], dict[str, Expr]], +) -> None: + """Expand one parsed alias and write entries into the parsed maps. + + Branches on whether the alias is templated (i.e. has a non-empty + expansion in ``ctx.alias_template_map``): templated aliases are + expanded once per assignment; bare aliases are stored under + ``raw_name``. + + Args: + raw_name: Original alias key (used as the dict key for + non-templated aliases). + canonical_name: Bracket-normalized form of ``raw_name`` (used to + look up template expansions). + ir_raw: Parsed alias body. + ctx: Bundled expansion context. + parsed_maps: Pair ``(reduce_parsed, full_parsed)`` of output maps + populated with reduce-form and point-expanded IRs. + """ + from op_system._ir_expand import expand_reduce_pointwise # noqa: PLC0415 + + reduce_parsed, full_parsed = parsed_maps + if canonical_name in ctx.alias_template_map: + for expanded_name, assignment in ctx.alias_template_map[canonical_name]: + ir_tmpl = expand_inline_templates( + ir_raw, + assignment=assignment, + shaped_params=ctx.shaped, + axis_lookup=ctx.ax_lookup, + ) + reduce_parsed[expanded_name] = ir_tmpl + full_parsed[expanded_name] = expand_reduce_pointwise( + ir_tmpl, + axes=list(ctx.axes), + shaped_params=ctx.shaped, + lhs_assignment=assignment, + axis_coords=ctx.ax_lookup, + ) + return + reduce_parsed[raw_name] = ir_raw + full_parsed[raw_name] = expand_reduce_pointwise( + ir_raw, + axes=list(ctx.axes), + shaped_params=ctx.shaped, + lhs_assignment={}, + axis_coords=ctx.ax_lookup, + ) + + +def _build_aliases_ir_from_raw( aliases_raw: Mapping[str, str], *, axes: list[dict[str, Any]], @@ -550,12 +653,7 @@ def _build_aliases_ir_from_raw( # noqa: C901 Returns: ``(aliases_ir, aliases_ir_reduce, alias_template_map)``. - - Raises: - InvalidRhsSpecError: If any alias body is an invalid expression. """ - from op_system._ir_expand import expand_reduce_pointwise # noqa: PLC0415 - shaped = shaped_params or {} ax_lookup = dict(axis_lookup or {}) @@ -571,45 +669,26 @@ def _build_aliases_ir_from_raw( # noqa: C901 for raw_name, expr_str in aliases_raw.items(): canonical_name = _normalize_bracket_key(raw_name) - expr_s = expr_str.strip() if isinstance(expr_str, str) else "" - if not expr_s: - raise InvalidRhsSpecError( - detail=f"aliases[{raw_name!r}] must be a non-empty string" - ) - try: - ir_raw = parse_expr_to_ir(expr_s, lower_helpers=True) - except Exception as exc: - raise InvalidRhsSpecError( - detail=f"aliases[{raw_name!r}] has invalid expression: {exc}" - ) from exc - - if canonical_name in alias_template_map: - for expanded_name, assignment in alias_template_map[canonical_name]: - ir_tmpl = expand_inline_templates( - ir_raw, - assignment=assignment, - shaped_params=shaped, - axis_lookup=ax_lookup, - ) - reduce_parsed[expanded_name] = ir_tmpl - full_parsed[expanded_name] = expand_reduce_pointwise( - ir_tmpl, - axes=list(axes), - shaped_params=shaped, - lhs_assignment=assignment, - axis_coords=ax_lookup, - ) - else: - reduce_parsed[raw_name] = ir_raw - full_parsed[raw_name] = expand_reduce_pointwise( - ir_raw, - axes=list(axes), - shaped_params=shaped, - lhs_assignment={}, - axis_coords=ax_lookup, - ) + ir_raw = _parse_alias_body_to_ir(raw_name, expr_str) + _populate_alias_parsed_maps( + raw_name, + canonical_name, + ir_raw, + _AliasExpansionContext( + alias_template_map=alias_template_map, + axes=axes, + shaped=shaped, + ax_lookup=ax_lookup, + ), + (reduce_parsed, full_parsed), + ) def _inline_all(parsed: dict[str, Expr]) -> dict[str, Expr]: + """Inline alias references through the alias map (cycle-safe). + + Returns: + Map from alias name to fully-inlined IR. + """ memo: dict[int, frozenset[str]] = {} cycle_ok = False with contextlib.suppress(ValueError, RecursionError): diff --git a/src/op_system/_symbols.py b/src/op_system/_symbols.py index 13a3800..026ac11 100644 --- a/src/op_system/_symbols.py +++ b/src/op_system/_symbols.py @@ -25,6 +25,12 @@ class ExpressionString: names: frozenset[str] = field(init=False) def __post_init__(self) -> None: + """Parse ``source`` once and freeze the cached AST plus name set. + + Raises: + InvalidExpressionError: If ``source`` is not a syntactically + valid Python expression. + """ try: parsed = ast.parse(self.source, mode="eval") except SyntaxError as exc: diff --git a/src/op_system/_templates.py b/src/op_system/_templates.py index cac5eff..606078c 100644 --- a/src/op_system/_templates.py +++ b/src/op_system/_templates.py @@ -451,6 +451,16 @@ def _apply_template_substitutions( # Inline placeholder syntax without an explicit template map entry. def _inline_replacer(match: re.Match[str]) -> str: + """Rewrite one ``base[axis, ...]`` occurrence using ``assignment``. + + Shaped-parameter subscripts get integer-coord indices substituted; + all other matches are rendered into the templated scalar name + produced by :func:`_render_template_name`. Matches whose + placeholders are not all bound by ``assignment`` are left intact. + + Returns: + Replacement text, or the original match if not applicable. + """ inner_base = match.group(1) inner = match.group(2) phs = [p.strip() for p in inner.split(",") if p.strip() and "=" not in p] diff --git a/src/op_system/_typing.py b/src/op_system/_typing.py index 6e81a5d..86a9730 100644 --- a/src/op_system/_typing.py +++ b/src/op_system/_typing.py @@ -34,18 +34,26 @@ class Array(Protocol): """ @property - def shape(self) -> tuple[int, ...]: ... + def shape(self) -> tuple[int, ...]: + """Return the array shape as a tuple of ``int``.""" + ... @property - def dtype(self) -> object: ... + def dtype(self) -> object: + """Return the array dtype object (namespace-specific).""" + ... def __array_namespace__( # noqa: PLW3201 self, *, api_version: Any = None, # noqa: ANN401 - ) -> object: ... + ) -> object: + """Return the Array-API namespace that owns this array.""" + ... - def item(self) -> Any: ... # noqa: ANN401 + def item(self) -> Any: # noqa: ANN401 + """Return the underlying scalar value (for 0-d arrays).""" + ... __all__ = ["Array"] diff --git a/src/op_system/_vectorize.py b/src/op_system/_vectorize.py index 2f38ed3..21e2560 100644 --- a/src/op_system/_vectorize.py +++ b/src/op_system/_vectorize.py @@ -291,6 +291,15 @@ def _lower_multicell_sym_ir_to_ast( # noqa: C901 ) def _lower(node: Expr) -> ast.expr | None: + """Best-effort IR -> AST lowering for the cell-equation rewriter. + + Returns ``None`` (rather than raising) for any node shape outside + the supported subset, which lets the caller fall back to the + generic per-cell lowering. + + Returns: + Lowered AST expression, or ``None`` if unsupported. + """ if isinstance(node, Sym): return cell_ast.get(node.name) if isinstance(node, Literal): @@ -647,6 +656,7 @@ def _bin_flat( base_flat: int = base_flat, vec_idx: tuple[int, ...] = vec_idx, ) -> int: + """Return the flat cell index for a vec-axis coord position.""" f = base_flat for k_pos, ax_pos in enumerate(vec_idx): f += vec_coord_idx[k_pos] * strides[ax_pos] @@ -982,6 +992,19 @@ def make_vectorized_eval_fn(plan: _VectorPlan) -> EvalFn: # noqa: C901, PLR0915 extra_param_buffers = plan.extra_param_buffers def eval_fn(t: object, y: object, **params: object) -> Float64Array: # noqa: C901, PLR0912, PLR0914, PLR0915 + """Vectorized evaluation closure produced by :func:`build_eval_fn`. + + Computes the RHS for the entire state vector at once using the + compiled per-bin AST callables, gathered into a single output + array in ``y``'s array namespace. + + Returns: + ``dydt`` array of shape ``(n_state,)`` in ``y``'s namespace. + + Raises: + ValueError: If a required parameter is missing or has the + wrong shape. + """ xp = _namespace_of(y) _check_numeric_dtype(xp, getattr(y, "dtype", None)) y_arr = _validate_state_vector(y, n_state=n_state) diff --git a/src/op_system/compile.py b/src/op_system/compile.py index 12cd0f6..542e842 100644 --- a/src/op_system/compile.py +++ b/src/op_system/compile.py @@ -23,6 +23,7 @@ import importlib import warnings from dataclasses import dataclass, field +from itertools import starmap from types import MappingProxyType from typing import ( TYPE_CHECKING, @@ -56,7 +57,11 @@ class _Indexable(Protocol): - def __getitem__(self, idx: int | slice) -> object: ... + """Minimal indexable protocol used to type-check state-vector access.""" + + def __getitem__(self, idx: int | slice) -> object: + """Return the element at ``idx``.""" + ... def _namespace_of(y: object) -> Any: # noqa: ANN401 @@ -193,9 +198,9 @@ def _raise_unsupported_feature(*, feature: str, detail: str | None = None) -> No class EvalFn(Protocol): """Callable RHS evaluator supporting runtime parameter kwargs.""" - def __call__( # noqa: D102 - self, t: object, y: object, **params: object - ) -> Float64Array: ... + def __call__(self, t: object, y: object, **params: object) -> Float64Array: + """Evaluate the RHS at time ``t`` and state ``y`` with bound parameters.""" + ... @dataclass(frozen=True, slots=True) @@ -240,6 +245,11 @@ def bind( params_dict = dict(params) def rhs(t: object, y: object) -> Float64Array: + """Two-argument RHS with parameters bound by the enclosing call. + + Returns: + ``dydt`` array in the namespace of ``y``. + """ return self.eval_fn(t, y, **params_dict) return rhs @@ -365,6 +375,15 @@ def _parse_expr(expr: str) -> ast.Expression: def _validate_call(func: ast.AST, *, expr: str) -> None: + """Validate that an AST call target is on the safe-eval allowlist. + + Permits attribute calls under whitelisted roots (currently ``np.``) and + bare names that match registered helper functions. + + Args: + func: ``func`` slot of an ``ast.Call`` node. + expr: Original expression string, used in diagnostics. + """ if isinstance(func, ast.Attribute): if not isinstance(func.value, ast.Name): _raise_invalid_expression(detail=f"invalid call root in {expr!r}") @@ -485,10 +504,7 @@ def _collect_eq_code( reserved_names=reserved_names, ) cse_code = tuple((name, _compile_expr(name, expr)) for name, expr in bindings) - eq_code = [ # noqa: FURB140 - _compile_expr(expr_s, expr_ir) - for expr_s, expr_ir in zip(equations, rewritten, strict=True) - ] + eq_code = list(starmap(_compile_expr, zip(equations, rewritten, strict=True))) return cse_code, eq_code return (), [_compile_expr(expr) for expr in equations] @@ -631,6 +647,15 @@ def _make_eval_fn( ) def eval_fn(t: object, y: object, **params: object) -> Float64Array: + """Namespace-polymorphic compiled RHS body. + + Infers the array namespace from ``y`` at call time, builds the + evaluation environment from state, parameters and aliases, and + returns the equation outputs stacked in ``y``'s namespace. + + Returns: + ``dydt`` array of shape ``(n_state,)`` in ``y``'s namespace. + """ xp = _namespace_of(y) _check_numeric_dtype(xp, getattr(y, "dtype", None)) y_arr = _validate_state_vector(y, n_state=n_state) @@ -642,12 +667,14 @@ def eval_fn(t: object, y: object, **params: object) -> Float64Array: env.update(params) def _sum_state() -> object: + """Return the sum of all state variables (``sum_state()`` helper).""" values = [v for k, v in env.items() if k in name_to_idx] if not values: return xp.asarray(0.0) return xp.sum(xp.stack(values)) def _sum_prefix(prefix: object) -> object: + """Return the sum of state variables whose names start with ``prefix``.""" pfx = str(prefix) values = [ v for k, v in env.items() if k.startswith(pfx) and k in name_to_idx @@ -782,6 +809,11 @@ def _wrap_eval_fn_for_time_varying( ) def wrapped(t: object, y: object, **params: object) -> Float64Array: + """Interpolate each time-varying parameter at ``t`` then dispatch. + + Returns: + ``dydt`` array produced by the wrapped ``eval_fn``. + """ xp = _namespace_of(y) for name, axis_pos in plan: if name not in params: @@ -820,6 +852,19 @@ def compile_rhs(rhs: NormalizedRhs, *, xp: object | None = None) -> CompiledRhs: Returns: A `CompiledRhs` containing an `eval_fn(t, y, **params) -> dydt`. + + Examples: + >>> import numpy as np + >>> from op_system.specs import normalize_rhs + >>> rhs = normalize_rhs({ + ... "kind": "expr", + ... "state": ["x"], + ... "equations": {"x": "2.0 * x"}, + ... }) + >>> compiled = compile_rhs(rhs) + >>> y = np.array([3.0]) + >>> compiled.eval_fn(0.0, y) + array([6.]) """ if xp is not None: warnings.warn(