diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py
new file mode 100644
index 000000000..b2da9094d
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py
@@ -0,0 +1,230 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+from data_designer.config.column_configs import GenerationStrategy
+from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task
+
+if TYPE_CHECKING:
+    from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
+
+
+class CompletionTracker:
+    """Tracks which cells (column, row_group, row_index) are done.
+
+    Row indices are local to their row group (0-based).
+
+    Use ``with_graph`` to create a frontier-enabled tracker where
+    ``get_ready_tasks`` returns in O(frontier) instead of scanning all
+    columns x rows x row groups.
+    """
+
+    def __init__(self) -> None:
+        # row_group → column → set of completed local row indices
+        self._completed: dict[int, dict[str, set[int]]] = defaultdict(lambda: defaultdict(set))
+        # row_group → set of dropped row indices
+        self._dropped: dict[int, set[int]] = defaultdict(set)
+
+        self._graph: ExecutionGraph | None = None
+        self._row_group_sizes: dict[int, int] = {}
+        self._batch_complete: dict[int, set[str]] = defaultdict(set)
+        self._frontier: set[Task] = set()
+
+    @classmethod
+    def with_graph(cls, graph: ExecutionGraph, row_groups: list[tuple[int, int]]) -> CompletionTracker:
+        """Create a frontier-enabled tracker backed by an execution graph."""
+        tracker = cls()
+        tracker._graph = graph
+        tracker._row_group_sizes = {rg_id: size for rg_id, size in row_groups}
+        tracker._seed_frontier()
+        return tracker
+
+    def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None:
+        self._validate_row_group(row_group)
+        self._validate_strategy(column, GenerationStrategy.CELL_BY_CELL, "mark_cell_complete")
+        self._completed[row_group][column].add(row_index)
+        if self._graph is not None:
+            self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell"))
+            self._enqueue_downstream(column, row_group, row_index=row_index)
+
+    def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> None:
+        expected = self._validate_row_group(row_group)
+        self._validate_strategy(column, GenerationStrategy.FULL_COLUMN, "mark_row_range_complete")
+        if expected is not None and row_group_size != expected:
+            raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}")
+        self._completed[row_group][column] = set(range(row_group_size))
+        self._batch_complete[row_group].add(column)
+        if self._graph is not None:
+            self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch"))
+            self._enqueue_downstream(column, row_group, row_index=None)
+
+    def is_complete(self, ref: SliceRef) -> bool:
+        return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set())
+
+    def is_all_complete(self, cells: list[SliceRef]) -> bool:
+        """Check whether all the given cells are done.
+
+        A ``row_index`` of ``None`` means the entire batch for that column must
+        have been completed via ``mark_row_range_complete``.
+        """
+        for ref in cells:
+            if ref.row_index is None:
+                if ref.column not in self._batch_complete.get(ref.row_group, set()):
+                    return False
+            elif not self.is_complete(ref):
+                return False
+        return True
+
+    def drop_row(self, row_group: int, row_index: int) -> None:
+        self._validate_row_group(row_group)
+        self._dropped[row_group].add(row_index)
+        if self._graph is not None:
+            # Remove cell tasks for this row from the frontier
+            for col in self._graph.columns:
+                self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell"))
+            # Dropping a row may unblock downstream FULL_COLUMN (batch) tasks
+            self._reevaluate_batch_tasks(row_group)
+
+    def is_dropped(self, row_group: int, row_index: int) -> bool:
+        return row_index in self._dropped.get(row_group, set())
+
+    def is_row_group_complete(
+        self,
+        row_group: int,
+        row_group_size: int,
+        all_columns: list[str],
+    ) -> bool:
+        """All non-dropped rows have all columns done."""
+        dropped = self._dropped.get(row_group, set())
+        completed = self._completed.get(row_group, {})
+        for ri in range(row_group_size):
+            if ri in dropped:
+                continue
+            for col in all_columns:
+                if ri not in completed.get(col, set()):
+                    return False
+        return True
+
+    def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]:
+        """Return all currently dispatchable tasks from the frontier.
+
+        Excludes already-dispatched/in-flight tasks.
+        """
+        return [t for t in self._frontier if t not in dispatched]
+
+    def _seed_frontier(self) -> None:
+        """Populate the frontier with root tasks (columns with no upstream deps)."""
+        if self._graph is None:
+            raise RuntimeError("This method requires a graph to be set.")
+        for col in self._graph.get_root_columns():
+            strategy = self._graph.get_strategy(col)
+            for rg_id, rg_size in self._row_group_sizes.items():
+                if strategy == GenerationStrategy.CELL_BY_CELL:
+                    for ri in range(rg_size):
+                        self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell"))
+                else:
+                    self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch"))
+
+    def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> None:
+        """Add newly-ready downstream tasks to the frontier."""
+        if self._graph is None:
+            raise RuntimeError("This method requires a graph to be set.")
+        rg_completed = self._completed.get(row_group, {})
+        rg_dropped = self._dropped.get(row_group, set())
+        rg_batch_complete = self._batch_complete.get(row_group, set())
+        rg_size = self._row_group_sizes[row_group]
+
+        for down in self._graph.get_downstream_columns(column):
+            batch_ups, cell_ups = self._graph.split_upstream_by_strategy(down)
+
+            if any(up not in rg_batch_complete for up in batch_ups):
+                continue
+
+            down_strategy = self._graph.get_strategy(down)
+
+            if down_strategy == GenerationStrategy.CELL_BY_CELL:
+                cell_up_completed = [rg_completed.get(up, set()) for up in cell_ups]
+                if row_index is not None:
+                    # Cell completion: only check the same row
+                    down_completed = rg_completed.get(down, set())
+                    if (
+                        row_index not in rg_dropped
+                        and row_index not in down_completed
+                        and all(row_index in s for s in cell_up_completed)
+                    ):
+                        task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell")
+                        self._frontier.add(task)
+                else:
+                    # Batch completion: check all non-dropped, non-complete rows
+                    down_completed = rg_completed.get(down, set())
+                    for ri in range(rg_size):
+                        if ri in rg_dropped or ri in down_completed:
+                            continue
+                        if all(ri in s for s in cell_up_completed):
+                            task = Task(column=down, row_group=row_group, row_index=ri, task_type="cell")
+                            self._frontier.add(task)
+            else:
+                # FULL_COLUMN downstream: ready when all cell upstreams are fully complete
+                if down not in rg_batch_complete and self._are_cell_ups_complete(
+                    cell_ups, rg_completed, rg_size, rg_dropped
+                ):
+                    task = Task(column=down, row_group=row_group, row_index=None, task_type="batch")
+                    self._frontier.add(task)
+
+    def _reevaluate_batch_tasks(self, row_group: int) -> None:
+        """Check if any batch tasks became ready after a row was dropped."""
+        if self._graph is None:
+            raise RuntimeError("This method requires a graph to be set.")
+        rg_completed = self._completed.get(row_group, {})
+        rg_dropped = self._dropped.get(row_group, set())
+        rg_batch_complete = self._batch_complete.get(row_group, set())
+        rg_size = self._row_group_sizes[row_group]
+
+        for col in self._graph.get_topological_order():
+            if self._graph.get_strategy(col) != GenerationStrategy.FULL_COLUMN:
+                continue
+            if col in rg_batch_complete:
+                continue
+            batch_ups, cell_ups = self._graph.split_upstream_by_strategy(col)
+            if any(up not in rg_batch_complete for up in batch_ups):
+                continue
+            if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped):
+                task = Task(column=col, row_group=row_group, row_index=None, task_type="batch")
+                self._frontier.add(task)
+
+    def _are_cell_ups_complete(
+        self,
+        cell_ups: list[str],
+        rg_completed: dict[str, set[int]],
+        rg_size: int,
+        rg_dropped: set[int],
+    ) -> bool:
+        """Check all non-dropped rows are complete for each cell-by-cell upstream column."""
+        for up in cell_ups:
+            up_completed = rg_completed.get(up, set())
+            for ri in range(rg_size):
+                if ri not in rg_dropped and ri not in up_completed:
+                    return False
+        return True
+
+    def _validate_strategy(self, column: str, expected: GenerationStrategy, method: str) -> None:
+        """Validate that *column* matches the expected strategy in graph-enabled mode."""
+        if self._graph is None:
+            return
+        actual = self._graph.get_strategy(column)
+        if actual != expected:
+            raise ValueError(f"{method}() requires {expected.value} strategy, but column '{column}' has {actual.value}")
+
+    def _validate_row_group(self, row_group: int) -> int | None:
+        """Validate row-group id in graph-enabled mode and return its expected size."""
+        if self._graph is None:
+            return None
+        expected = self._row_group_sizes.get(row_group)
+        if expected is None:
+            known = sorted(self._row_group_sizes)
+            raise ValueError(f"Unknown row_group {row_group}. Known row_groups: {known}")
+        return expected
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py
new file mode 100644
index 000000000..29db09c83
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py
@@ -0,0 +1,260 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import math
+from collections import deque
+
+from data_designer.config.column_configs import GenerationStrategy
+from data_designer.engine.dataset_builders.multi_column_configs import (
+    DatasetBuilderColumnConfigT,
+    MultiColumnConfig,
+)
+from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
+from data_designer.engine.dataset_builders.utils.task_model import SliceRef
+
+
+class ExecutionGraph:
+    """Column-level static execution graph built from column configs.
+
+    Nodes are columns (O(C)); edges are dependency relationships (O(C²) worst-case).
+    The graph is fixed for the lifetime of a run — runtime readiness is tracked
+    separately by ``CompletionTracker``.
+ """ + + def __init__(self) -> None: + self._upstream: dict[str, set[str]] = {} + self._downstream: dict[str, set[str]] = {} + self._strategies: dict[str, GenerationStrategy] = {} + self._side_effect_map: dict[str, str] = {} + self._columns: list[str] = [] + self._topological_order_cache: list[str] | None = None + self._upstream_by_strategy_cache: dict[str, tuple[list[str], list[str]]] = {} + + @property + def columns(self) -> list[str]: + """All column names in insertion order.""" + return list(self._columns) + + @classmethod + def create( + cls, + column_configs: list[DatasetBuilderColumnConfigT], + strategies: dict[str, GenerationStrategy], + ) -> ExecutionGraph: + """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. + + Args: + column_configs: Ordered list of ``ColumnConfigT`` or ``MultiColumnConfig``. + strategies: Map of column name → ``GenerationStrategy``, obtained from + each generator's ``get_generation_strategy()``. + """ + graph = cls() + + # First pass: register all columns, strategies, and side-effect mappings + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + graph.add_column(name, strategies[name]) + + for se_col in sub.side_effect_columns: + graph.set_side_effect(se_col, name) + + known_columns = set(graph.columns) + + # Second pass: build edges + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + for req in sub.required_columns: + resolved = graph.resolve_side_effect(req) + if resolved not in known_columns: + raise ValueError( + f"Column '{name}' requires '{req}' (resolved to '{resolved}') which is not a known producer." 
+ ) + if resolved == name: + continue # skip self-dependency + graph.add_edge(upstream=resolved, downstream=name) + + # Validate acyclicity + graph.get_topological_order() + + return graph + + def add_column(self, name: str, strategy: GenerationStrategy) -> None: + """Register a column and its generation strategy.""" + if name in self._strategies: + raise ValueError(f"Column '{name}' is already registered.") + self._columns.append(name) + self._strategies[name] = strategy + + def add_edge(self, upstream: str, downstream: str) -> None: + """Add a dependency edge: *downstream* depends on *upstream*.""" + self._upstream.setdefault(downstream, set()).add(upstream) + self._downstream.setdefault(upstream, set()).add(downstream) + + def set_side_effect(self, side_effect_col: str, producer: str) -> None: + """Map a side-effect column name to its producing column.""" + self._side_effect_map[side_effect_col] = producer + + def resolve_side_effect(self, column: str) -> str: + """Resolve a column name through the side-effect map. + + If a real column exists with the same name as a side-effect alias, + the real column wins. + """ + if column in self._strategies: + return column + return self._side_effect_map.get(column, column) + + def get_upstream_columns(self, column: str) -> set[str]: + """Direct dependencies of *column*.""" + return set(self._upstream.get(column, set())) + + def get_downstream_columns(self, column: str) -> set[str]: + """Columns that depend on *column*.""" + return set(self._downstream.get(column, set())) + + def get_strategy(self, column: str) -> GenerationStrategy: + return self._strategies[column] + + def get_root_columns(self) -> list[str]: + """Columns with no upstream dependencies, in topological order.""" + return [col for col in self.get_topological_order() if not self._upstream.get(col)] + + def split_upstream_by_strategy(self, column: str) -> tuple[list[str], list[str]]: + """Split upstream columns into (batch, cell_by_cell) by strategy. 
Cached.""" + cached = self._upstream_by_strategy_cache.get(column) + if cached is not None: + return cached + batch: list[str] = [] + cell: list[str] = [] + for up_col in self.get_upstream_columns(column): + if self._strategies[up_col] == GenerationStrategy.CELL_BY_CELL: + cell.append(up_col) + else: + batch.append(up_col) + result = (batch, cell) + self._upstream_by_strategy_cache[column] = result + return result + + def get_topological_order(self) -> list[str]: + """Return a valid topological ordering of columns (Kahn's algorithm). + + Result is cached after first successful computation since the graph is + immutable after construction. + """ + if self._topological_order_cache is not None: + return list(self._topological_order_cache) + + in_degree: dict[str, int] = {col: 0 for col in self._columns} + for col, deps in self._upstream.items(): + if col in in_degree: + in_degree[col] = len(deps) + + queue = deque(col for col, deg in in_degree.items() if deg == 0) + order: list[str] = [] + while queue: + col = queue.popleft() + order.append(col) + for child in self._downstream.get(col, set()): + if child in in_degree: + in_degree[child] -= 1 + if in_degree[child] == 0: + queue.append(child) + + if len(order) != len(self._columns): + raise DAGCircularDependencyError( + f"The execution graph contains cyclic dependencies. Resolved {len(order)}/{len(self._columns)} columns." 
+ ) + + self._topological_order_cache = order + return list(order) + + def get_longest_dependency_chain(self) -> list[str]: + """Longest dependency chain (by number of columns).""" + order = self.get_topological_order() + if not order: + return [] + dist: dict[str, int] = {col: 0 for col in order} + pred: dict[str, str | None] = {col: None for col in order} + + for col in order: + for child in self._downstream.get(col, set()): + if dist[col] + 1 > dist[child]: + dist[child] = dist[col] + 1 + pred[child] = col + + end = max(order, key=lambda c: dist[c]) + path: list[str] = [] + cur: str | None = end + while cur is not None: + path.append(cur) + cur = pred[cur] + path.reverse() + return path + + def compute_task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: + """Exact task count per column before the run starts. + + Cell-by-cell columns produce ``num_records`` tasks. + Full-column columns (including from-scratch) produce ``ceil(num_records / buffer_size)`` tasks. + """ + if buffer_size <= 0: + raise ValueError(f"buffer_size must be a positive integer, got {buffer_size}") + num_row_groups = math.ceil(num_records / buffer_size) + counts: dict[str, int] = {} + for col in self._columns: + strat = self._strategies[col] + if strat == GenerationStrategy.CELL_BY_CELL: + counts[col] = num_records + else: + counts[col] = num_row_groups + return counts + + def compute_cell_dependencies( + self, + column: str, + row_group: int, + row_index: int | None, + row_group_size: int, + ) -> list[SliceRef]: + """Derive cell-level deps on demand from column-level DAG + strategy. + + Returns a list of ``SliceRef`` that must be complete before this task can run. 
+ """ + deps: list[SliceRef] = [] + for up_col in self.get_upstream_columns(column): + up_strategy = self._strategies[up_col] + if up_strategy == GenerationStrategy.CELL_BY_CELL: + if row_index is not None: + deps.append(SliceRef(up_col, row_group, row_index)) + else: + for ri in range(row_group_size): + deps.append(SliceRef(up_col, row_group, ri)) + else: + deps.append(SliceRef(up_col, row_group, None)) + return deps + + def to_mermaid(self) -> str: + """Mermaid diagram string with strategy annotations.""" + lines = ["graph TD"] + for col in self._columns: + strat = self._strategies[col] + label = f"{col} [{strat.value}]" + lines.append(f' {col}["{label}"]') + for col in self._columns: + for dep in sorted(self._upstream.get(col, set())): + lines.append(f" {dep} --> {col}") + return "\n".join(lines) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py new file mode 100644 index 000000000..574c594c1 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Literal
+
+
+@dataclass(frozen=True, order=True)
+class SliceRef:
+    """Reference to a slice of the execution grid: a single cell or a full row group."""
+
+    column: str
+    row_group: int
+    row_index: int | None = None
+
+
+@dataclass(frozen=True)
+class Task:
+    """A unit of work for the async scheduler."""
+
+    column: str
+    row_group: int
+    row_index: int | None  # None for batch/full-column tasks
+    task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"]
+
+
+@dataclass
+class TaskResult:
+    """Outcome of a completed task."""
+
+    task: Task
+    status: Literal["success", "error"]
+    output: Any = None
+    error: Exception | None = None
+    retryable: bool = False
+
+
+@dataclass
+class TaskTrace:
+    """Timing trace for a single task. Only created when tracing is enabled."""
+
+    column: str
+    row_group: int
+    row_index: int | None
+    task_type: str
+    dispatched_at: float = 0.0
+    slot_acquired_at: float = 0.0
+    completed_at: float = 0.0
+    status: str = ""
+    error: str | None = None
+
+    @classmethod
+    def from_task(cls, task: Task) -> TaskTrace:
+        return cls(
+            column=task.column,
+            row_group=task.row_group,
+            row_index=task.row_index,
+            task_type=task.task_type,
+        )
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py
index dec52401a..9cd9bcc62 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py
@@ -1,6 +1,8 @@
 # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 from data_designer.engine.models.clients.base import ModelClient
 from data_designer.engine.models.clients.errors import (
     ProviderError,
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py
index 1b65e2dde..cc9feefea 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py
@@ -1,6 +1,8 @@
 # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient, LiteLLMRouter
 
 __all__ = ["LiteLLMBridgeClient", "LiteLLMRouter"]
diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py
new file mode 100644
index 000000000..b0e9f8024
--- /dev/null
+++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py
@@ -0,0 +1,348 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import pytest
+
+from data_designer.config.column_configs import (
+    ExpressionColumnConfig,
+    GenerationStrategy,
+    LLMTextColumnConfig,
+    SamplerColumnConfig,
+)
+from data_designer.config.sampler_params import SamplerType
+from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker
+from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph
+from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task
+
+MODEL_ALIAS = "stub"
+
+
+def _build_simple_graph() -> ExecutionGraph:
+    """topic (full-column) → question (cell-by-cell) → score (full-column)."""
+    configs = [
+        SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
+        LLMTextColumnConfig(name="question", prompt="{{ topic }}", model_alias=MODEL_ALIAS),
+        ExpressionColumnConfig(name="score", expr="{{ question }}"),
+    ]
+    strategies = {
+        "topic": GenerationStrategy.FULL_COLUMN,
+        "question": GenerationStrategy.CELL_BY_CELL,
+        "score": GenerationStrategy.FULL_COLUMN,
+    }
+    return ExecutionGraph.create(configs, strategies)
+
+
+@dataclass
+class ReadyTasksFixture:
+    tracker: CompletionTracker
+    dispatched: set[Task]
+
+
+@pytest.fixture()
+def ready_ctx() -> ReadyTasksFixture:
+    """CompletionTracker wired to the simple 3-column graph with one row group of size 3."""
+    graph = _build_simple_graph()
+    return ReadyTasksFixture(
+        tracker=CompletionTracker.with_graph(graph, [(0, 3)]),
+        dispatched=set(),
+    )
+
+
+# -- mark_cell_complete / is_complete --------------------------------------
+
+
+def test_mark_and_check_complete() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_cell_complete("col_a", row_group=0, row_index=0)
+
+    assert tracker.is_complete(SliceRef("col_a", 0, 0))
+    assert not tracker.is_complete(SliceRef("col_a", 0, 1))
+    assert not tracker.is_complete(SliceRef("col_a", 1, 0))
+    assert not tracker.is_complete(SliceRef("col_b", 0, 0))
+
+
+def test_mark_row_range_complete() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_row_range_complete("col_a", row_group=0, row_group_size=3)
+
+    assert tracker.is_complete(SliceRef("col_a", 0, 0))
+    assert tracker.is_complete(SliceRef("col_a", 0, 1))
+    assert tracker.is_complete(SliceRef("col_a", 0, 2))
+    assert not tracker.is_complete(SliceRef("col_a", 0, 3))
+
+
+def test_mark_row_range_complete_raises_on_size_mismatch(ready_ctx: ReadyTasksFixture) -> None:
+    with pytest.raises(ValueError, match="Row-group size mismatch"):
+        ready_ctx.tracker.mark_row_range_complete("topic", row_group=0, row_group_size=2)
+
+
+def test_mark_cell_complete_raises_on_unknown_row_group(ready_ctx: ReadyTasksFixture) -> None:
+    with pytest.raises(ValueError, match="Unknown row_group"):
+        ready_ctx.tracker.mark_cell_complete("question", row_group=999, row_index=0)
+
+
+# -- is_all_complete -----------------------------------------------------------
+
+
+def test_all_complete_cell_level() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_cell_complete("col_a", 0, 0)
+    tracker.mark_cell_complete("col_a", 0, 1)
+
+    assert tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 1)])
+    assert not tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 2)])
+
+
+def test_all_complete_batch_level() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_row_range_complete("col_a", 0, 3)
+
+    assert tracker.is_all_complete([SliceRef("col_a", 0, None)])
+
+
+def test_all_complete_batch_single_cell_not_sufficient() -> None:
+    """mark_cell_complete on one row must NOT make is_all_complete return True for batch check."""
+    tracker = CompletionTracker()
+    tracker.mark_cell_complete("col_a", 0, 0)
+
+    assert not tracker.is_all_complete([SliceRef("col_a", 0, None)])
+
+
+def test_all_complete_batch_not_present() -> None:
+    tracker = CompletionTracker()
+    assert not tracker.is_all_complete([SliceRef("col_a", 0, None)])
+
+
+def test_all_complete_empty_list() -> None:
+    tracker = CompletionTracker()
+    assert tracker.is_all_complete([])
+
+
+# -- drop_row / is_dropped -------------------------------------------------
+
+
+def test_drop_row() -> None:
+    tracker = CompletionTracker()
+    tracker.drop_row(row_group=0, row_index=2)
+
+    assert tracker.is_dropped(0, 2)
+    assert not tracker.is_dropped(0, 0)
+    assert not tracker.is_dropped(1, 2)
+
+
+# -- is_row_group_complete --------------------------------------------------
+
+
+def test_row_group_complete() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_row_range_complete("col_a", 0, 3)
+    tracker.mark_row_range_complete("col_b", 0, 3)
+
+    assert tracker.is_row_group_complete(0, 3, ["col_a", "col_b"])
+
+
+def test_row_group_incomplete() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_row_range_complete("col_a", 0, 3)
+
+    assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"])
+
+
+def test_row_group_complete_with_dropped_rows() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_cell_complete("col_a", 0, 0)
+    tracker.mark_cell_complete("col_a", 0, 2)
+    tracker.mark_cell_complete("col_b", 0, 0)
+    tracker.mark_cell_complete("col_b", 0, 2)
+    tracker.drop_row(0, 1)  # row 1 is dropped
+
+    assert tracker.is_row_group_complete(0, 3, ["col_a", "col_b"])
+
+
+def test_row_group_not_complete_missing_non_dropped() -> None:
+    tracker = CompletionTracker()
+    tracker.mark_cell_complete("col_a", 0, 0)
+    tracker.mark_cell_complete("col_b", 0, 0)
+    tracker.drop_row(0, 1)
+    # row 2 is not dropped and not complete
+
+    assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"])
+
+
+# -- get_ready_tasks --------------------------------------------------------
+
+
+def test_get_ready_tasks_seeds_first(ready_ctx: ReadyTasksFixture) -> None:
+    ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+
+    assert len(ready) == 1
+    assert ready[0].column == "topic"
+    assert ready[0].task_type == "batch"
+
+
+def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None:
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+
+    ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+
+    question_tasks = [t for t in ready if t.column == "question"]
+    assert len(question_tasks) == 3
+    assert all(t.task_type == "cell" for t in question_tasks)
+    assert {t.row_index for t in question_tasks} == {0, 1, 2}
+
+
+def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None:
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+
+    ready1 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+    ready_ctx.dispatched.update(ready1)
+
+    ready2 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+    assert len(ready2) == 0
+
+
+def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None:
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+    ready_ctx.tracker.drop_row(0, 1)
+
+    ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+
+    question_tasks = [t for t in ready if t.column == "question"]
+    assert len(question_tasks) == 2
+    assert {t.row_index for t in question_tasks} == {0, 2}
+
+
+def test_drop_row_unblocks_full_column_downstream(ready_ctx: ReadyTasksFixture) -> None:
+    """Dropping the last incomplete CELL_BY_CELL row should make downstream FULL_COLUMN ready."""
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+    ready_ctx.tracker.mark_cell_complete("question", 0, 0)
+    ready_ctx.tracker.mark_cell_complete("question", 0, 1)
+    # question[2] never completes -- drop it instead
+    ready_ctx.tracker.drop_row(0, 2)
+
+    ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+    score_tasks = [t for t in ready if t.column == "score"]
+    assert len(score_tasks) == 1
+    assert score_tasks[0].task_type == "batch"
+
+
+def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None:
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+    ready_ctx.tracker.mark_cell_complete("question", 0, 0)
+    ready_ctx.tracker.mark_cell_complete("question", 0, 1)
+    # question[2] not done yet
+
+    ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+
+    score_tasks = [t for t in ready if t.column == "score"]
+    assert len(score_tasks) == 0
+
+
+def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None:
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+    for ri in range(3):
+        ready_ctx.tracker.mark_cell_complete("question", 0, ri)
+
+    ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+
+    score_tasks = [t for t in ready if t.column == "score"]
+    assert len(score_tasks) == 1
+    assert score_tasks[0].task_type == "batch"
+
+
+def test_get_ready_tasks_multiple_row_groups() -> None:
+    graph = _build_simple_graph()
+    tracker = CompletionTracker.with_graph(graph, [(0, 3), (1, 2)])
+    dispatched: set[Task] = set()
+
+    tracker.mark_row_range_complete("topic", 0, 3)
+    tracker.mark_row_range_complete("topic", 1, 2)
+
+    ready = tracker.get_ready_tasks(dispatched)
+
+    question_tasks = [t for t in ready if t.column == "question"]
+    assert len(question_tasks) == 5  # 3 from rg0 + 2 from rg1
+
+
+def test_get_ready_tasks_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None:
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+
+    ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched)
+
+    topic_tasks = [t for t in ready if t.column == "topic"]
+    assert len(topic_tasks) == 0
+
+
+# -- Strategy-safe completion API ------------------------------------------
+
+
+def test_mark_cell_complete_raises_for_full_column_strategy(ready_ctx: ReadyTasksFixture) -> None:
+    with pytest.raises(ValueError, match="mark_cell_complete.*requires cell_by_cell.*full_column"):
+        ready_ctx.tracker.mark_cell_complete("topic", row_group=0, row_index=0)
+
+
+def test_mark_row_range_complete_raises_for_cell_by_cell_strategy(ready_ctx: ReadyTasksFixture) -> None:
+    ready_ctx.tracker.mark_row_range_complete("topic", 0, 3)
+    with pytest.raises(ValueError, match="mark_row_range_complete.*requires full_column.*cell_by_cell"):
+        ready_ctx.tracker.mark_row_range_complete("question", row_group=0, row_group_size=3)
+
+
+# -- Re-enqueue regression tests -------------------------------------------
+
+
+def test_completed_cell_not_reenqueued_after_later_upstream() -> None:
+    """A → B → C chain: completing C then firing a late upstream event must not re-enqueue C."""
+    graph = _build_simple_graph()
+    tracker = CompletionTracker.with_graph(graph, [(0, 2)])
+    dispatched: set[Task] = set()
+
+    # Complete the full pipeline
+    tracker.mark_row_range_complete("topic", 0, 2)
+    tracker.mark_cell_complete("question", 0, 0)
+    tracker.mark_cell_complete("question", 0, 1)
+    tracker.mark_row_range_complete("score", 0, 2)
+
+    # Fire a late upstream cell event after score is already done
+    tracker.mark_cell_complete("question", 0, 0)
+
+    ready = tracker.get_ready_tasks(dispatched)
+    score_tasks = [t for t in ready if t.column == "score"]
+    assert len(score_tasks) == 0
+
+
+def test_completed_batch_not_reenqueued_by_upstream_cell() -> None:
+    """After a FULL_COLUMN downstream is completed, a late cell upstream event must not re-add it."""
+    configs = [
+        SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
+        LLMTextColumnConfig(name="gen", prompt="{{ seed }}", model_alias=MODEL_ALIAS),
+        ExpressionColumnConfig(name="agg", expr="{{ gen }}"),
+    ]
+    strategies = {
+        "seed": GenerationStrategy.FULL_COLUMN,
+        "gen": GenerationStrategy.CELL_BY_CELL,
+        "agg": GenerationStrategy.FULL_COLUMN,
+    }
+    graph = ExecutionGraph.create(configs, strategies)
+    tracker = CompletionTracker.with_graph(graph, [(0, 2)])
+    dispatched: set[Task] = set()
+
+    # Complete seed and gen[0] — agg not ready yet
+    tracker.mark_row_range_complete("seed", 0, 2)
+    tracker.mark_cell_complete("gen", 0, 0)
+
+    ready = tracker.get_ready_tasks(dispatched)
+    assert not any(t.column == "agg" for t in ready)
+
+    # Complete gen[1] — agg becomes ready
+    tracker.mark_cell_complete("gen", 0, 1)
+    ready = tracker.get_ready_tasks(dispatched)
+    assert any(t.column == "agg" for t in ready)
+
+    # Complete agg, then verify it doesn't reappear
+    tracker.mark_row_range_complete("agg", 0, 2)
+    ready = tracker.get_ready_tasks(dispatched)
+    assert not any(t.column == "agg" for t in ready)
diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py
new file mode 100644
index 000000000..9d2fa69c4
--- /dev/null
+++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py
@@ -0,0 +1,450 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.config.column_configs import ( + ExpressionColumnConfig, + GenerationStrategy, + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, + Score, + ValidationColumnConfig, +) +from data_designer.config.sampler_params import SamplerType +from data_designer.config.utils.code_lang import CodeLang +from data_designer.config.validator_params import CodeValidatorParams +from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig +from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError +from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph +from data_designer.engine.dataset_builders.utils.task_model import SliceRef + +MODEL_ALIAS = "stub-model-alias" + + +# -- Fixtures ---------------------------------------------------------------- + + +@pytest.fixture() +def simple_pipeline_configs() -> list: + """topic (sampler) → question (llm) → answer (llm) → score (expression).""" + return [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A", "B"]}), + LLMTextColumnConfig(name="question", prompt="Ask about {{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="answer", prompt="Answer {{ question }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="score", expr="{{ answer }}"), + ] + + +@pytest.fixture() +def simple_pipeline_strategies() -> dict[str, GenerationStrategy]: + return { + "topic": GenerationStrategy.FULL_COLUMN, + "question": GenerationStrategy.CELL_BY_CELL, + "answer": GenerationStrategy.CELL_BY_CELL, + "score": GenerationStrategy.FULL_COLUMN, + } + + +@pytest.fixture() +def simple_graph( + simple_pipeline_configs: list, + simple_pipeline_strategies: dict[str, GenerationStrategy], +) -> ExecutionGraph: + return ExecutionGraph.create(simple_pipeline_configs, 
simple_pipeline_strategies) + + +# -- Graph construction tests ------------------------------------------------ + + +def test_build_basic_graph(simple_graph: ExecutionGraph) -> None: + assert simple_graph.columns == ["topic", "question", "answer", "score"] + assert simple_graph.get_upstream_columns("topic") == set() + assert simple_graph.get_upstream_columns("question") == {"topic"} + assert simple_graph.get_upstream_columns("answer") == {"question"} + assert simple_graph.get_upstream_columns("score") == {"answer"} + + +def test_get_downstream_columns(simple_graph: ExecutionGraph) -> None: + assert simple_graph.get_downstream_columns("topic") == {"question"} + assert simple_graph.get_downstream_columns("question") == {"answer"} + assert simple_graph.get_downstream_columns("answer") == {"score"} + assert simple_graph.get_downstream_columns("score") == set() + + +def test_strategy(simple_graph: ExecutionGraph) -> None: + assert simple_graph.get_strategy("topic") == GenerationStrategy.FULL_COLUMN + assert simple_graph.get_strategy("question") == GenerationStrategy.CELL_BY_CELL + + +def test_unknown_column_get_upstream_columns() -> None: + graph = ExecutionGraph() + assert graph.get_upstream_columns("nonexistent") == set() + + +def test_unknown_column_get_downstream_columns() -> None: + graph = ExecutionGraph() + assert graph.get_downstream_columns("nonexistent") == set() + + +# -- Side-effect resolution ------------------------------------------------- + + +def test_side_effect_column_resolution() -> None: + configs = [ + LLMTextColumnConfig( + name="summary", + prompt="Summarize", + model_alias=MODEL_ALIAS, + with_trace="last_message", + ), + ExpressionColumnConfig(name="trace_len", expr="{{ summary__trace }}"), + ] + strategies = { + "summary": GenerationStrategy.CELL_BY_CELL, + "trace_len": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + + assert graph.get_upstream_columns("trace_len") == {"summary"} + assert 
graph.get_downstream_columns("summary") == {"trace_len"} + + +def test_reasoning_content_side_effect() -> None: + configs = [ + LLMTextColumnConfig( + name="answer", + prompt="Think step by step", + model_alias=MODEL_ALIAS, + extract_reasoning_content=True, + ), + ExpressionColumnConfig(name="reasoning", expr="{{ answer__reasoning_content }}"), + ] + strategies = { + "answer": GenerationStrategy.CELL_BY_CELL, + "reasoning": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + + assert graph.get_upstream_columns("reasoning") == {"answer"} + + +def test_side_effect_name_collision_prefers_real_column() -> None: + configs = [ + LLMTextColumnConfig( + name="summary", + prompt="Summarize", + model_alias=MODEL_ALIAS, + with_trace="last_message", + ), + SamplerColumnConfig(name="summary__trace", sampler_type=SamplerType.CATEGORY, params={"values": ["OVERRIDE"]}), + ExpressionColumnConfig(name="trace_len", expr="{{ summary__trace }}"), + ] + strategies = { + "summary": GenerationStrategy.CELL_BY_CELL, + "summary__trace": GenerationStrategy.FULL_COLUMN, + "trace_len": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + + assert graph.get_upstream_columns("trace_len") == {"summary__trace"} + assert graph.get_downstream_columns("summary__trace") == {"trace_len"} + assert graph.get_downstream_columns("summary") == set() + + +# -- Validation tests ------------------------------------------------------- + + +def test_circular_dependency_raises() -> None: + configs = [ + LLMTextColumnConfig(name="col_a", prompt="{{ col_b }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="col_b", prompt="{{ col_a }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "col_a": GenerationStrategy.CELL_BY_CELL, + "col_b": GenerationStrategy.CELL_BY_CELL, + } + with pytest.raises(DAGCircularDependencyError): + ExecutionGraph.create(configs, strategies) + + +def test_unknown_required_column_raises() -> None: + configs = 
[ + LLMTextColumnConfig(name="col_a", prompt="{{ nonexistent }}", model_alias=MODEL_ALIAS), + ] + strategies = {"col_a": GenerationStrategy.CELL_BY_CELL} + with pytest.raises(ValueError, match="not a known producer"): + ExecutionGraph.create(configs, strategies) + + +# -- Topological order ------------------------------------------------------ + + +def test_topological_order(simple_graph: ExecutionGraph) -> None: + order = simple_graph.get_topological_order() + idx = {col: i for i, col in enumerate(order)} + + assert idx["topic"] < idx["question"] + assert idx["question"] < idx["answer"] + assert idx["answer"] < idx["score"] + + +def test_parallel_columns_topological_order() -> None: + """Two independent columns after a shared root.""" + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["X"]}), + LLMTextColumnConfig(name="branch_a", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="branch_b", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="merge", expr="{{ branch_a }} {{ branch_b }}"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "branch_a": GenerationStrategy.CELL_BY_CELL, + "branch_b": GenerationStrategy.CELL_BY_CELL, + "merge": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + order = graph.get_topological_order() + idx = {col: i for i, col in enumerate(order)} + + assert idx["seed"] < idx["branch_a"] + assert idx["seed"] < idx["branch_b"] + assert idx["branch_a"] < idx["merge"] + assert idx["branch_b"] < idx["merge"] + + +# -- Critical path ---------------------------------------------------------- + + +def test_get_longest_dependency_chain_empty_graph() -> None: + graph = ExecutionGraph() + assert graph.get_longest_dependency_chain() == [] + + +def test_get_longest_dependency_chain(simple_graph: ExecutionGraph) -> None: + path = simple_graph.get_longest_dependency_chain() + assert path == 
["topic", "question", "answer", "score"] + + +def test_get_longest_dependency_chain_diamond() -> None: + """Diamond: seed → (a, b) → merge. Path is seed → a/b → merge (length 3).""" + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["X"]}), + LLMTextColumnConfig(name="a", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="b", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="merge", expr="{{ a }} {{ b }}"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "a": GenerationStrategy.CELL_BY_CELL, + "b": GenerationStrategy.CELL_BY_CELL, + "merge": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + path = graph.get_longest_dependency_chain() + + assert len(path) == 3 + assert path[0] == "seed" + assert path[-1] == "merge" + + +# -- Task count ------------------------------------------------------------- + + +def test_task_count(simple_graph: ExecutionGraph) -> None: + counts = simple_graph.compute_task_count(num_records=10, buffer_size=3) + + assert counts["topic"] == 4 # ceil(10/3) = 4 row groups + assert counts["question"] == 10 # cell-by-cell + assert counts["answer"] == 10 # cell-by-cell + assert counts["score"] == 4 # full-column + + +def test_task_count_exact_divisor(simple_graph: ExecutionGraph) -> None: + counts = simple_graph.compute_task_count(num_records=9, buffer_size=3) + + assert counts["topic"] == 3 + assert counts["question"] == 9 + + +@pytest.mark.parametrize("buffer_size", [0, -1]) +def test_task_count_invalid_buffer_size_raises(simple_graph: ExecutionGraph, buffer_size: int) -> None: + with pytest.raises(ValueError, match="buffer_size"): + simple_graph.compute_task_count(num_records=10, buffer_size=buffer_size) + + +def test_add_column_duplicate_raises() -> None: + graph = ExecutionGraph() + graph.add_column("col_a", GenerationStrategy.CELL_BY_CELL) + with pytest.raises(ValueError, 
match="already registered"): + graph.add_column("col_a", GenerationStrategy.FULL_COLUMN) + + +# -- Cell dependencies ------------------------------------------------------ + + +def test_cell_deps_cell_by_cell_upstream(simple_graph: ExecutionGraph) -> None: + """question depends on topic (full-column); answer depends on question (cell-by-cell).""" + # answer[rg=0, row=2] should depend on question[rg=0, row=2] + deps = simple_graph.compute_cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) + assert deps == [SliceRef("question", 0, 2)] + + +def test_cell_deps_full_column_upstream(simple_graph: ExecutionGraph) -> None: + """question depends on topic (full-column).""" + deps = simple_graph.compute_cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) + assert deps == [SliceRef("topic", 0, None)] + + +def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: + """topic has no upstream.""" + deps = simple_graph.compute_cell_dependencies("topic", row_group=0, row_index=None, row_group_size=5) + assert deps == [] + + +def test_cell_deps_full_column_downstream_of_cell_by_cell(simple_graph: ExecutionGraph) -> None: + """score (full-column) depends on answer (cell-by-cell) → needs ALL rows.""" + deps = simple_graph.compute_cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) + assert sorted(deps) == [SliceRef("answer", 0, 0), SliceRef("answer", 0, 1), SliceRef("answer", 0, 2)] + + +# -- Mermaid output ---------------------------------------------------------- + + +def test_to_mermaid(simple_graph: ExecutionGraph) -> None: + mermaid = simple_graph.to_mermaid() + + assert "graph TD" in mermaid + assert 'topic["topic [full_column]"]' in mermaid + assert 'question["question [cell_by_cell]"]' in mermaid + assert "topic --> question" in mermaid + assert "question --> answer" in mermaid + assert "answer --> score" in mermaid + + +# -- MultiColumnConfig ------------------------------------------------------- + + 
+def test_multi_column_config() -> None: + """Multi-column sampler config: all sub-columns share the same strategy.""" + multi = SamplerMultiColumnConfig( + columns=[ + SamplerColumnConfig(name="first_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Alice"]}), + SamplerColumnConfig(name="last_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Smith"]}), + ] + ) + configs = [multi] + strategies = { + "first_name": GenerationStrategy.FULL_COLUMN, + "last_name": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + + assert set(graph.columns) == {"first_name", "last_name"} + assert graph.get_upstream_columns("first_name") == set() + assert graph.get_upstream_columns("last_name") == set() + + +def test_multi_column_with_downstream_dependency() -> None: + multi = SamplerMultiColumnConfig( + columns=[ + SamplerColumnConfig(name="first_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Alice"]}), + SamplerColumnConfig(name="last_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Smith"]}), + ] + ) + greeting = LLMTextColumnConfig( + name="greeting", + prompt="Hello {{ first_name }} {{ last_name }}", + model_alias=MODEL_ALIAS, + ) + configs = [multi, greeting] + strategies = { + "first_name": GenerationStrategy.FULL_COLUMN, + "last_name": GenerationStrategy.FULL_COLUMN, + "greeting": GenerationStrategy.CELL_BY_CELL, + } + graph = ExecutionGraph.create(configs, strategies) + + assert graph.get_upstream_columns("greeting") == {"first_name", "last_name"} + + +# -- Validation column dependency ------------------------------------------- + + +def test_validation_column_dependency() -> None: + configs = [ + LLMCodeColumnConfig( + name="code", + prompt="Write code", + code_lang=CodeLang.PYTHON, + model_alias=MODEL_ALIAS, + ), + ValidationColumnConfig( + name="validation", + target_columns=["code"], + validator_type="code", + validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON), + ), + ] 
+ strategies = { + "code": GenerationStrategy.CELL_BY_CELL, + "validation": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + + assert graph.get_upstream_columns("validation") == {"code"} + assert graph.get_downstream_columns("code") == {"validation"} + + +# -- Immutability tests ----------------------------------------------------- + + +def test_mutating_columns_does_not_affect_graph(simple_graph: ExecutionGraph) -> None: + cols = simple_graph.columns + cols.append("injected") + assert "injected" not in simple_graph.columns + + +def test_mutating_upstream_does_not_affect_graph(simple_graph: ExecutionGraph) -> None: + ups = simple_graph.get_upstream_columns("question") + ups.add("injected") + assert "injected" not in simple_graph.get_upstream_columns("question") + + +def test_mutating_downstream_does_not_affect_graph(simple_graph: ExecutionGraph) -> None: + downs = simple_graph.get_downstream_columns("topic") + downs.add("injected") + assert "injected" not in simple_graph.get_downstream_columns("topic") + + +def test_mutating_topological_order_does_not_affect_cache(simple_graph: ExecutionGraph) -> None: + order1 = simple_graph.get_topological_order() + order1.reverse() + order2 = simple_graph.get_topological_order() + assert order2[0] == "topic" + + +# -- Judge column dependency ------------------------------------------------ + + +def test_judge_column_dependency() -> None: + configs = [ + LLMTextColumnConfig(name="text", prompt="Write something", model_alias=MODEL_ALIAS), + LLMJudgeColumnConfig( + name="judge", + prompt="Judge {{ text }}", + scores=[Score(name="quality", description="Quality", options={0: "Bad", 1: "Good"})], + model_alias=MODEL_ALIAS, + ), + ] + strategies = { + "text": GenerationStrategy.CELL_BY_CELL, + "judge": GenerationStrategy.CELL_BY_CELL, + } + graph = ExecutionGraph.create(configs, strategies) + + assert graph.get_upstream_columns("judge") == {"text"} diff --git 
a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py new file mode 100644 index 000000000..5d5716213 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.engine.dataset_builders.utils.task_model import Task, TaskResult, TaskTrace + + +def test_task_is_frozen() -> None: + task = Task(column="col_a", row_group=0, row_index=1, task_type="cell") + with pytest.raises(AttributeError): + task.column = "col_b" # type: ignore[misc] + + +def test_task_hashable_and_in_set() -> None: + t1 = Task(column="col_a", row_group=0, row_index=1, task_type="cell") + t2 = Task(column="col_a", row_group=0, row_index=1, task_type="cell") + t3 = Task(column="col_a", row_group=0, row_index=2, task_type="cell") + + assert t1 == t2 + assert t1 != t3 + assert hash(t1) == hash(t2) + + s: set[Task] = {t1, t2, t3} + assert len(s) == 2 + + +def test_task_batch_has_none_row_index() -> None: + task = Task(column="col_a", row_group=0, row_index=None, task_type="batch") + assert task.row_index is None + + +@pytest.mark.parametrize( + "task_type", + ["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"], +) +def test_task_types(task_type: str) -> None: + task = Task(column="col", row_group=0, row_index=0, task_type=task_type) + assert task.task_type == task_type + + +def test_task_result_success() -> None: + task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") + result = TaskResult(task=task, status="success", output={"col_a": "value"}) + assert result.status == "success" + assert result.error is None + assert result.retryable is False + + +def 
test_task_result_error() -> None: + task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") + exc = ValueError("bad input") + result = TaskResult(task=task, status="error", error=exc, retryable=True) + assert result.status == "error" + assert result.error is exc + assert result.retryable is True + + +def test_task_trace_from_task() -> None: + task = Task(column="col_a", row_group=1, row_index=2, task_type="cell") + trace = TaskTrace.from_task(task) + + assert trace.column == "col_a" + assert trace.row_group == 1 + assert trace.row_index == 2 + assert trace.task_type == "cell" + assert trace.dispatched_at == 0.0 + assert trace.status == "" + + +def test_task_trace_mutable() -> None: + task = Task(column="col_a", row_group=0, row_index=None, task_type="batch") + trace = TaskTrace.from_task(task) + + trace.dispatched_at = 1.0 + trace.slot_acquired_at = 1.5 + trace.completed_at = 2.0 + trace.status = "ok" + + assert trace.dispatched_at == 1.0 + assert trace.completed_at - trace.slot_acquired_at == pytest.approx(0.5) + + +def test_task_equality_differs_by_type() -> None: + t1 = Task(column="col_a", row_group=0, row_index=None, task_type="batch") + t2 = Task(column="col_a", row_group=0, row_index=None, task_type="from_scratch") + assert t1 != t2 diff --git a/plans/346/async-generators-and-task-queue.md b/plans/346/async-generators-and-task-queue.md index a9b6b418e..0ac4e7ba7 100644 --- a/plans/346/async-generators-and-task-queue.md +++ b/plans/346/async-generators-and-task-queue.md @@ -1,7 +1,7 @@ # Plan: Async Generators & Task Queue Builder Created: 2026-02-20 -Status: Planning +Status: In Progress Issue: [#346](https://github.com/NVIDIA-NeMo/DataDesigner/issues/346) @@ -225,7 +225,7 @@ The graph is column-granularity only — no cell-level nodes — so it stays sma O(C²) edges worst-case) regardless of row count and avoids the barrier/checkpoint problems of a cell-level graph. 
-- [ ] `ExecutionGraph` class: +- [x] `ExecutionGraph` class: - Backing stores: `dict[str, set[str]]` column → upstream columns; `dict[str, GenerationStrategy]` column → generation strategy - `upstream(column: str) -> set[str]` — direct dependencies of a column @@ -238,7 +238,7 @@ graph. full-column columns (including from-scratch generators, which report `FULL_COLUMN`) produce `ceil(num_records / buffer_size)` tasks - `to_mermaid() -> str` — Mermaid diagram string; nodes are annotated with strategy type -- [ ] `build_execution_graph(column_configs, strategies: dict[str, GenerationStrategy]) -> ExecutionGraph` utility: +- [x] `build_execution_graph(column_configs, strategies: dict[str, GenerationStrategy]) -> ExecutionGraph` utility: - Input: the ordered list of `ColumnConfigT` / `MultiColumnConfig`, plus a pre-computed strategy map (available from generators at builder init time via `get_generation_strategy()`) - For each config, read `config.required_columns` → set of upstream column names @@ -247,7 +247,7 @@ graph. - For `MultiColumnConfig`, all sub-columns share the same dependencies - Validate: every required column must resolve to a known producer (including registered side-effect outputs), and the graph must be acyclic -- [ ] Unit tests for graph construction, validation, critical path, task count, and Mermaid output +- [x] Unit tests for graph construction, validation, critical path, task count, and Mermaid output **Files**: new module `engine/dataset_builders/utils/execution_graph.py`, tests @@ -257,7 +257,7 @@ A lightweight structure tracking which (column, row_group, row_index) tuples are done. Row indices are **local** to their row group (0-based within each group), matching the buffer manager's per-row-group addressing. 
-- [ ] `CompletionTracker` class: +- [x] `CompletionTracker` class: - Internal: `dict[int, dict[str, set[int]]]` mapping row_group → column → set of completed local row indices - `mark_complete(column: str, row_group: int, row_index: int)` / `mark_batch_complete(column: str, row_group: int, row_group_size: int)` - `is_ready(column: str, row_group: int, row_index: int, graph: ExecutionGraph) -> bool` — checks all upstream columns for that (row_group, row_index) @@ -266,8 +266,8 @@ matching the buffer manager's per-row-group addressing. `get_ready_tasks` skips dropped rows, in-flight tasks for dropped rows are ignored on completion - `is_row_group_complete(row_group: int, row_group_size: int, all_columns: list[str]) -> bool` — all non-dropped rows have all columns done; `row_group_size` is the original size, dropped rows (via `drop_row`) are excluded internally - `get_ready_tasks(graph: ExecutionGraph, row_groups, dispatched: set[Task]) -> list[Task]` — yields all currently dispatchable tasks, excluding dropped rows and already-dispatched/in-flight tasks; reads `graph.strategy(column)` to determine task granularity per column -- [ ] No locks needed: all access is from the single asyncio event loop thread -- [ ] Unit tests +- [x] No locks needed: all access is from the single asyncio event loop thread +- [x] Unit tests **Files**: new module `engine/dataset_builders/utils/completion_tracker.py`, tests @@ -275,19 +275,19 @@ matching the buffer manager's per-row-group addressing. Simple dataclass representing a unit of work. 
-- [ ] `Task` dataclass: +- [x] `Task` dataclass: - `column: str` - `row_group: int` - `row_index: int | None` (None for batch tasks) - `task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"]` -- [ ] `TaskResult` with status, output, error info -- [ ] `TaskTrace` dataclass (only instantiated when tracing is enabled): +- [x] `TaskResult` with status, output, error info +- [x] `TaskTrace` dataclass (only instantiated when tracing is enabled): - `column: str`, `row_group: int`, `row_index: int | None`, `task_type: str` - `dispatched_at: float` — `perf_counter()` when `create_task()` fires - `slot_acquired_at: float` — after execution semaphore acquired - `completed_at: float` — in `finally` block after generator returns - `status: str`, `error: str | None` -- [ ] Hashable so we can track dispatched/pending sets +- [x] Hashable so we can track dispatched/pending sets **Files**: new module `engine/dataset_builders/utils/task_model.py` — must be its own module since `CompletionTracker`, `AsyncTaskScheduler`, and the buffer manager all reference `Task`/`TaskResult`; @@ -432,15 +432,15 @@ Wire the new scheduler into `ColumnWiseDatasetBuilder`. Tests are added incrementally with each PR, not deferred to the end. 
**PR 1 (foundation) — unit tests**: -- [ ] Execution graph construction, validation, topological order, critical path -- [ ] Execution graph: side-effect output columns resolve correctly (e.g., column +- [x] Execution graph construction, validation, topological order, critical path +- [x] Execution graph: side-effect output columns resolve correctly (e.g., column depending on `summary__trace` maps to a dependency on the `summary` generator) -- [ ] Execution graph: `cell_dependencies` returns correct deps for cell-by-cell, +- [x] Execution graph: `cell_dependencies` returns correct deps for cell-by-cell, full-column, and from-scratch columns -- [ ] Execution graph: `task_count` and `to_mermaid` output -- [ ] Completion tracker: `mark_complete`, `is_complete`, `all_complete` -- [ ] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete` -- [ ] Task model: hashability, equality, TaskResult, TaskTrace +- [x] Execution graph: `task_count` and `to_mermaid` output +- [x] Completion tracker: `mark_complete`, `is_complete`, `all_complete` +- [x] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete` +- [x] Task model: hashability, equality, TaskResult, TaskTrace **PR 2 (generators) — unit tests**: - [ ] Symmetric bridging: sync-only generator can be called via `agenerate`