From d6b3022711eb1489fac4a1681783b155758646fb Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 25 Feb 2026 23:58:41 -0300 Subject: [PATCH 01/17] feat: add ExecutionGraph, CompletionTracker, and Task model for async scheduler Add the foundational data structures for the async task-queue dataset builder (plan #346, PR 1/4): - ExecutionGraph: column-level static DAG with topological ordering, critical path, task counts, cell-dependency resolution, Mermaid output, and side-effect column mapping (__trace, __reasoning_content). - CompletionTracker: lightweight (column, row_group, row_index) completion state with row dropping and ready-task enumeration. - Task/TaskResult/TaskTrace: frozen hashable task dataclass, result container, and opt-in tracing record. All three are pure data structures with no side effects on the existing codebase. They live in new modules under engine/dataset_builders/utils/ and are only imported by code introduced in later PRs. 56 unit tests covering graph construction, validation, dependency resolution, completion tracking, row drops, and task model semantics. Refs #346 --- .../utils/completion_tracker.py | 111 +++++ .../dataset_builders/utils/execution_graph.py | 201 +++++++++ .../dataset_builders/utils/task_model.py | 52 +++ .../utils/test_completion_tracker.py | 257 ++++++++++++ .../utils/test_execution_graph.py | 381 ++++++++++++++++++ .../dataset_builders/utils/test_task_model.py | 87 ++++ plans/346/async-generators-and-task-queue.md | 36 +- 7 files changed, 1107 insertions(+), 18 deletions(-) create mode 100644 packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py create mode 100644 packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py create mode 100644 packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py create mode 100644 packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py new file mode 100644 index 000000000..e98dd962f --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING + +from data_designer.config.column_configs import GenerationStrategy +from data_designer.engine.dataset_builders.utils.task_model import Task + +if TYPE_CHECKING: + from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph + + +class CompletionTracker: + """Tracks which (column, row_group, row_index) tuples are done. + + All access is from the single asyncio event loop thread — no locks needed. + Row indices are local to their row group (0-based). + """ + + def __init__(self) -> None: + # row_group → column → set of completed local row indices + self._completed: dict[int, dict[str, set[int]]] = defaultdict(lambda: defaultdict(set)) + # row_group → set of dropped row indices + self._dropped: dict[int, set[int]] = defaultdict(set) + + def mark_complete(self, column: str, row_group: int, row_index: int) -> None: + self._completed[row_group][column].add(row_index) + + def mark_batch_complete(self, column: str, row_group: int, row_group_size: int) -> None: + self._completed[row_group][column] = set(range(row_group_size)) + + def is_complete(self, column: str, row_group: int, row_index: int) -> bool: + return row_index in self._completed.get(row_group, {}).get(column, set()) + + def all_complete(self, cells: list[tuple[str, int, int | None]]) -> bool: + """Check whether all the given (column, row_group, row_index) tuples are done. + + A ``row_index`` of ``None`` means the entire batch for that column must + be complete (i.e., that column key must exist in the row group's dict). + """ + for col, rg, ri in cells: + if ri is None: + if col not in self._completed.get(rg, {}): + return False + elif not self.is_complete(col, rg, ri): + return False + return True + + def drop_row(self, row_group: int, row_index: int) -> None: + self._dropped[row_group].add(row_index) + + def is_dropped(self, row_group: int, row_index: int) -> bool: + return row_index in self._dropped.get(row_group, set()) + + def is_row_group_complete( + self, + row_group: int, + row_group_size: int, + all_columns: list[str], + ) -> bool: + """All non-dropped rows have all columns done.""" + dropped = self._dropped.get(row_group, set()) + completed = self._completed.get(row_group, {}) + for ri in range(row_group_size): + if ri in dropped: + continue + for col in all_columns: + if ri not in completed.get(col, set()): + return False + return True + + def get_ready_tasks( + self, + graph: ExecutionGraph, + row_groups: list[tuple[int, int]], + dispatched: set[Task], + ) -> list[Task]: + """Return all currently dispatchable tasks. + + Excludes dropped rows and already-dispatched/in-flight tasks. + """ + ready: list[Task] = [] + for rg_id, rg_size in row_groups: + for col in graph.topological_order(): + strategy = graph.strategy(col) + if strategy == GenerationStrategy.CELL_BY_CELL: + for ri in range(rg_size): + if self.is_dropped(rg_id, ri): + continue + if self.is_complete(col, rg_id, ri): + continue + task = Task(column=col, row_group=rg_id, row_index=ri, task_type="cell") + if task in dispatched: + continue + deps = graph.cell_dependencies(col, rg_id, ri, rg_size) + if self.all_complete(deps): + ready.append(task) + else: + task = Task(column=col, row_group=rg_id, row_index=None, task_type="batch") + if task in dispatched: + continue + # Check if already complete (batch-level) + if col in self._completed.get(rg_id, {}): + continue + deps = graph.cell_dependencies(col, rg_id, None, rg_size) + if self.all_complete(deps): + ready.append(task) + return ready diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py new file mode 100644 index 000000000..d6e4ccbcc --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math +from collections import defaultdict, deque +from dataclasses import dataclass, field + +from data_designer.config.column_configs import GenerationStrategy +from data_designer.config.column_types import ColumnConfigT +from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError + +DatasetBuilderColumnConfigT = ColumnConfigT | MultiColumnConfig + + +@dataclass +class ExecutionGraph: + """Column-level static execution graph built from column configs. + + Nodes are columns (O(C)); edges are dependency relationships (O(C²) worst-case). + The graph is fixed for the lifetime of a run — runtime readiness is tracked + separately by ``CompletionTracker``. + """ + + _upstream: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) + _downstream: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) + _strategies: dict[str, GenerationStrategy] = field(default_factory=dict) + _side_effect_map: dict[str, str] = field(default_factory=dict) + _columns: list[str] = field(default_factory=list) + + def upstream(self, column: str) -> set[str]: + """Direct dependencies of *column*.""" + return self._upstream.get(column, set()) + + def downstream(self, column: str) -> set[str]: + """Columns that depend on *column*.""" + return self._downstream.get(column, set()) + + def strategy(self, column: str) -> GenerationStrategy: + return self._strategies[column] + + @property + def columns(self) -> list[str]: + """All column names in insertion order.""" + return list(self._columns) + + def topological_order(self) -> list[str]: + """Return a valid topological ordering of columns (Kahn's algorithm).""" + in_degree: dict[str, int] = {col: 0 for col in self._columns} + for col, deps in self._upstream.items(): + if col in in_degree: + in_degree[col] = len(deps) + + queue = deque(col for col, deg in in_degree.items() if deg == 0) + order: list[str] = [] + while queue: + col = queue.popleft() + order.append(col) + for child in self._downstream.get(col, set()): + if child in in_degree: + in_degree[child] -= 1 + if in_degree[child] == 0: + queue.append(child) + + if len(order) != len(self._columns): + raise DAGCircularDependencyError( + f"The execution graph contains cyclic dependencies. Resolved {len(order)}/{len(self._columns)} columns." + ) + return order + + def critical_path(self) -> list[str]: + """Longest dependency chain (by number of columns).""" + order = self.topological_order() + dist: dict[str, int] = {col: 0 for col in order} + pred: dict[str, str | None] = {col: None for col in order} + + for col in order: + for child in self._downstream.get(col, set()): + if dist[col] + 1 > dist[child]: + dist[child] = dist[col] + 1 + pred[child] = col + + end = max(order, key=lambda c: dist[c]) + path: list[str] = [] + cur: str | None = end + while cur is not None: + path.append(cur) + cur = pred[cur] + path.reverse() + return path + + def task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: + """Exact task count per column before the run starts. + + Cell-by-cell columns produce ``num_records`` tasks. + Full-column columns (including from-scratch) produce ``ceil(num_records / buffer_size)`` tasks. + """ + num_row_groups = math.ceil(num_records / buffer_size) + counts: dict[str, int] = {} + for col in self._columns: + strat = self._strategies[col] + if strat == GenerationStrategy.CELL_BY_CELL: + counts[col] = num_records + else: + counts[col] = num_row_groups + return counts + + def cell_dependencies( + self, + column: str, + row_group: int, + row_index: int | None, + row_group_size: int, + ) -> list[tuple[str, int, int | None]]: + """Derive cell-level deps on demand from column-level DAG + strategy. + + Returns a list of ``(upstream_column, row_group, row_index)`` tuples + that must be complete before this task can run. + """ + deps: list[tuple[str, int, int | None]] = [] + for up_col in self.upstream(column): + up_strategy = self._strategies[up_col] + if up_strategy == GenerationStrategy.CELL_BY_CELL: + if row_index is not None: + deps.append((up_col, row_group, row_index)) + else: + for ri in range(row_group_size): + deps.append((up_col, row_group, ri)) + else: + deps.append((up_col, row_group, None)) + return deps + + def to_mermaid(self) -> str: + """Mermaid diagram string with strategy annotations.""" + lines = ["graph TD"] + for col in self._columns: + strat = self._strategies[col] + label = f"{col} [{strat.value}]" + lines.append(f' {col}["{label}"]') + for col in self._columns: + for dep in sorted(self._upstream.get(col, set())): + lines.append(f" {dep} --> {col}") + return "\n".join(lines) + + +def build_execution_graph( + column_configs: list[DatasetBuilderColumnConfigT], + strategies: dict[str, GenerationStrategy], +) -> ExecutionGraph: + """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. + + Args: + column_configs: Ordered list of ``ColumnConfigT`` or ``MultiColumnConfig``. + strategies: Map of column name → ``GenerationStrategy``, obtained from + each generator's ``get_generation_strategy()``. + """ + graph = ExecutionGraph() + + # First pass: register all columns, strategies, and side-effect mappings + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + graph._columns.append(name) + graph._strategies[name] = strategies[name] + + for se_col in sub.side_effect_columns: + graph._side_effect_map[se_col] = name + + known_columns = set(graph._columns) + + # Second pass: build edges + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + for req in sub.required_columns: + resolved = graph._side_effect_map.get(req, req) + if resolved not in known_columns: + raise ValueError( + f"Column '{name}' requires '{req}' (resolved to '{resolved}') which is not a known producer." + ) + if resolved == name: + continue # skip self-dependency + graph._upstream[name].add(resolved) + graph._downstream[resolved].add(name) + + # Validate acyclicity + graph.topological_order() + + return graph diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py new file mode 100644 index 000000000..c43c9362c --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + + +@dataclass(frozen=True) +class Task: + """A unit of work for the async scheduler.""" + + column: str + row_group: int + row_index: int | None # None for batch/full-column tasks + task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"] + + +@dataclass +class TaskResult: + """Outcome of a completed task.""" + + task: Task + status: Literal["success", "error"] + output: Any = None + error: Exception | None = None + retryable: bool = False + + +@dataclass +class TaskTrace: + """Timing trace for a single task. Only created when tracing is enabled.""" + + column: str + row_group: int + row_index: int | None + task_type: str + dispatched_at: float = 0.0 + slot_acquired_at: float = 0.0 + completed_at: float = 0.0 + status: str = "" + error: str | None = None + + @staticmethod + def from_task(task: Task) -> TaskTrace: + return TaskTrace( + column=task.column, + row_group=task.row_group, + row_index=task.row_index, + task_type=task.task_type, + ) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py new file mode 100644 index 000000000..3fd904a52 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +# Helpers to build minimal graphs without real column configs +from data_designer.config.column_configs import ( + ExpressionColumnConfig, + GenerationStrategy, + LLMTextColumnConfig, + SamplerColumnConfig, +) +from data_designer.config.sampler_params import SamplerType +from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph, build_execution_graph +from data_designer.engine.dataset_builders.utils.task_model import Task + +MODEL_ALIAS = "stub" + + +def _build_simple_graph() -> ExecutionGraph: + """topic (full-column) → question (cell-by-cell) → score (full-column).""" + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="question", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="score", expr="{{ question }}"), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "question": GenerationStrategy.CELL_BY_CELL, + "score": GenerationStrategy.FULL_COLUMN, + } + return build_execution_graph(configs, strategies) + + +# -- mark_complete / is_complete ------------------------------------------- + + +def test_mark_and_check_complete() -> None: + tracker = CompletionTracker() + tracker.mark_complete("col_a", row_group=0, row_index=0) + + assert tracker.is_complete("col_a", 0, 0) + assert not tracker.is_complete("col_a", 0, 1) + assert not tracker.is_complete("col_a", 1, 0) + assert not tracker.is_complete("col_b", 0, 0) + + +def test_mark_batch_complete() -> None: + tracker = CompletionTracker() + tracker.mark_batch_complete("col_a", row_group=0, row_group_size=3) + + assert tracker.is_complete("col_a", 0, 0) + assert tracker.is_complete("col_a", 0, 1) + assert tracker.is_complete("col_a", 0, 2) + assert not tracker.is_complete("col_a", 0, 3) + + +# -- all_complete ----------------------------------------------------------- + + +def test_all_complete_cell_level() -> None: + tracker = CompletionTracker() + tracker.mark_complete("col_a", 0, 0) + tracker.mark_complete("col_a", 0, 1) + + assert tracker.all_complete([("col_a", 0, 0), ("col_a", 0, 1)]) + assert not tracker.all_complete([("col_a", 0, 0), ("col_a", 0, 2)]) + + +def test_all_complete_batch_level() -> None: + tracker = CompletionTracker() + tracker.mark_batch_complete("col_a", 0, 3) + + assert tracker.all_complete([("col_a", 0, None)]) + + +def test_all_complete_batch_not_present() -> None: + tracker = CompletionTracker() + assert not tracker.all_complete([("col_a", 0, None)]) + + +def test_all_complete_empty_list() -> None: + tracker = CompletionTracker() + assert tracker.all_complete([]) + + +# -- drop_row / is_dropped ------------------------------------------------- + + +def test_drop_row() -> None: + tracker = CompletionTracker() + tracker.drop_row(row_group=0, row_index=2) + + assert tracker.is_dropped(0, 2) + assert not tracker.is_dropped(0, 0) + assert not tracker.is_dropped(1, 2) + + +# -- is_row_group_complete -------------------------------------------------- + + +def test_row_group_complete() -> None: + tracker = CompletionTracker() + tracker.mark_batch_complete("col_a", 0, 3) + tracker.mark_batch_complete("col_b", 0, 3) + + assert tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) + + +def test_row_group_incomplete() -> None: + tracker = CompletionTracker() + tracker.mark_batch_complete("col_a", 0, 3) + + assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) + + +def test_row_group_complete_with_dropped_rows() -> None: + tracker = CompletionTracker() + tracker.mark_complete("col_a", 0, 0) + tracker.mark_complete("col_a", 0, 2) + tracker.mark_complete("col_b", 0, 0) + tracker.mark_complete("col_b", 0, 2) + tracker.drop_row(0, 1) # row 1 is dropped + + assert tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) + + +def test_row_group_not_complete_missing_non_dropped() -> None: + tracker = CompletionTracker() + tracker.mark_complete("col_a", 0, 0) + tracker.mark_complete("col_b", 0, 0) + tracker.drop_row(0, 1) + # row 2 is not dropped and not complete + + assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) + + +# -- get_ready_tasks -------------------------------------------------------- + + +def test_get_ready_tasks_seeds_first() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + + # Only the seed column should be ready (no upstream) + assert len(ready) == 1 + assert ready[0].column == "topic" + assert ready[0].task_type == "batch" + + +def test_get_ready_tasks_after_seed_complete() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + tracker.mark_batch_complete("topic", 0, 3) + + ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + + # All question cells should be ready (topic is done) + question_tasks = [t for t in ready if t.column == "question"] + assert len(question_tasks) == 3 + assert all(t.task_type == "cell" for t in question_tasks) + assert {t.row_index for t in question_tasks} == {0, 1, 2} + + +def test_get_ready_tasks_skips_dispatched() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + tracker.mark_batch_complete("topic", 0, 3) + + ready1 = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + dispatched.update(ready1) + + ready2 = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + assert len(ready2) == 0 + + +def test_get_ready_tasks_skips_dropped_rows() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + tracker.mark_batch_complete("topic", 0, 3) + tracker.drop_row(0, 1) + + ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + + question_tasks = [t for t in ready if t.column == "question"] + assert len(question_tasks) == 2 + assert {t.row_index for t in question_tasks} == {0, 2} + + +def test_get_ready_tasks_full_column_waits_for_all_cells() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + tracker.mark_batch_complete("topic", 0, 3) + tracker.mark_complete("question", 0, 0) + tracker.mark_complete("question", 0, 1) + # question[0,2] not done yet + + ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + + score_tasks = [t for t in ready if t.column == "score"] + assert len(score_tasks) == 0 # score waits for all question rows + + +def test_get_ready_tasks_full_column_ready_when_all_cells_done() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + tracker.mark_batch_complete("topic", 0, 3) + for ri in range(3): + tracker.mark_complete("question", 0, ri) + + ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + + score_tasks = [t for t in ready if t.column == "score"] + assert len(score_tasks) == 1 + assert score_tasks[0].task_type == "batch" + + +def test_get_ready_tasks_multiple_row_groups() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + # Both row groups have topic done + tracker.mark_batch_complete("topic", 0, 3) + tracker.mark_batch_complete("topic", 1, 2) + + ready = tracker.get_ready_tasks(graph, [(0, 3), (1, 2)], dispatched) + + question_tasks = [t for t in ready if t.column == "question"] + assert len(question_tasks) == 5 # 3 from rg0 + 2 from rg1 + + +def test_get_ready_tasks_skips_already_complete_batch() -> None: + graph = _build_simple_graph() + tracker = CompletionTracker() + dispatched: set[Task] = set() + + tracker.mark_batch_complete("topic", 0, 3) + + ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + + # topic is already complete, should not be in ready tasks + topic_tasks = [t for t in ready if t.column == "topic"] + assert len(topic_tasks) == 0 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py new file mode 100644 index 000000000..8a714fb16 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from data_designer.config.column_configs import ( + ExpressionColumnConfig, + GenerationStrategy, + LLMCodeColumnConfig, + LLMJudgeColumnConfig, + LLMTextColumnConfig, + SamplerColumnConfig, + Score, + ValidationColumnConfig, +) +from data_designer.config.sampler_params import SamplerType +from data_designer.config.utils.code_lang import CodeLang +from data_designer.config.validator_params import CodeValidatorParams +from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig +from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError +from data_designer.engine.dataset_builders.utils.execution_graph import ( + ExecutionGraph, + build_execution_graph, +) + +MODEL_ALIAS = "stub-model-alias" + + +# -- Fixtures ---------------------------------------------------------------- + + +@pytest.fixture() +def simple_pipeline_configs() -> list: + """topic (sampler) → question (llm) → answer (llm) → score (expression).""" + return [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A", "B"]}), + LLMTextColumnConfig(name="question", prompt="Ask about {{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="answer", prompt="Answer {{ question }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="score", expr="{{ answer }}"), + ] + + +@pytest.fixture() +def simple_pipeline_strategies() -> dict[str, GenerationStrategy]: + return { + "topic": GenerationStrategy.FULL_COLUMN, + "question": GenerationStrategy.CELL_BY_CELL, + "answer": GenerationStrategy.CELL_BY_CELL, + "score": GenerationStrategy.FULL_COLUMN, + } + + +@pytest.fixture() +def simple_graph( + simple_pipeline_configs: list, + simple_pipeline_strategies: dict[str, GenerationStrategy], +) -> ExecutionGraph: + return build_execution_graph(simple_pipeline_configs, simple_pipeline_strategies) + + +# -- Graph construction tests ------------------------------------------------ + + +def test_build_basic_graph(simple_graph: ExecutionGraph) -> None: + assert simple_graph.columns == ["topic", "question", "answer", "score"] + assert simple_graph.upstream("topic") == set() + assert simple_graph.upstream("question") == {"topic"} + assert simple_graph.upstream("answer") == {"question"} + assert simple_graph.upstream("score") == {"answer"} + + +def test_downstream(simple_graph: ExecutionGraph) -> None: + assert simple_graph.downstream("topic") == {"question"} + assert simple_graph.downstream("question") == {"answer"} + assert simple_graph.downstream("answer") == {"score"} + assert simple_graph.downstream("score") == set() + + +def test_strategy(simple_graph: ExecutionGraph) -> None: + assert simple_graph.strategy("topic") == GenerationStrategy.FULL_COLUMN + assert simple_graph.strategy("question") == GenerationStrategy.CELL_BY_CELL + + +def test_unknown_column_upstream() -> None: + graph = ExecutionGraph() + assert graph.upstream("nonexistent") == set() + + +def test_unknown_column_downstream() -> None: + graph = ExecutionGraph() + assert graph.downstream("nonexistent") == set() + + +# -- Side-effect resolution ------------------------------------------------- + + +def test_side_effect_column_resolution() -> None: + configs = [ + LLMTextColumnConfig( + name="summary", + prompt="Summarize", + model_alias=MODEL_ALIAS, + with_trace="last_message", + ), + ExpressionColumnConfig(name="trace_len", expr="{{ summary__trace }}"), + ] + strategies = { + "summary": GenerationStrategy.CELL_BY_CELL, + "trace_len": GenerationStrategy.FULL_COLUMN, + } + graph = build_execution_graph(configs, strategies) + + assert graph.upstream("trace_len") == {"summary"} + assert graph.downstream("summary") == {"trace_len"} + + +def test_reasoning_content_side_effect() -> None: + configs = [ + LLMTextColumnConfig( + name="answer", + prompt="Think step by step", + model_alias=MODEL_ALIAS, + extract_reasoning_content=True, + ), + ExpressionColumnConfig(name="reasoning", expr="{{ answer__reasoning_content }}"), + ] + strategies = { + "answer": GenerationStrategy.CELL_BY_CELL, + "reasoning": GenerationStrategy.FULL_COLUMN, + } + graph = build_execution_graph(configs, strategies) + + assert graph.upstream("reasoning") == {"answer"} + + +# -- Validation tests ------------------------------------------------------- + + +def test_circular_dependency_raises() -> None: + configs = [ + LLMTextColumnConfig(name="col_a", prompt="{{ col_b }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="col_b", prompt="{{ col_a }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "col_a": GenerationStrategy.CELL_BY_CELL, + "col_b": GenerationStrategy.CELL_BY_CELL, + } + with pytest.raises(DAGCircularDependencyError): + build_execution_graph(configs, strategies) + + +def test_unknown_required_column_raises() -> None: + configs = [ + LLMTextColumnConfig(name="col_a", prompt="{{ nonexistent }}", model_alias=MODEL_ALIAS), + ] + strategies = {"col_a": GenerationStrategy.CELL_BY_CELL} + with pytest.raises(ValueError, match="not a known producer"): + build_execution_graph(configs, strategies) + + +# -- Topological order ------------------------------------------------------ + + +def test_topological_order(simple_graph: ExecutionGraph) -> None: + order = simple_graph.topological_order() + idx = {col: i for i, col in enumerate(order)} + + assert idx["topic"] < idx["question"] + assert idx["question"] < idx["answer"] + assert idx["answer"] < idx["score"] + + +def test_parallel_columns_topological_order() -> None: + """Two independent columns after a shared root.""" + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["X"]}), + LLMTextColumnConfig(name="branch_a", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="branch_b", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="merge", expr="{{ branch_a }} {{ branch_b }}"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "branch_a": GenerationStrategy.CELL_BY_CELL, + "branch_b": GenerationStrategy.CELL_BY_CELL, + "merge": GenerationStrategy.FULL_COLUMN, + } + graph = build_execution_graph(configs, strategies) + order = graph.topological_order() + idx = {col: i for i, col in enumerate(order)} + + assert idx["seed"] < idx["branch_a"] + assert idx["seed"] < idx["branch_b"] + assert idx["branch_a"] < idx["merge"] + assert idx["branch_b"] < idx["merge"] + + +# -- Critical path ---------------------------------------------------------- + + +def test_critical_path(simple_graph: ExecutionGraph) -> None: + path = simple_graph.critical_path() + assert path == ["topic", "question", "answer", "score"] + + +def test_critical_path_diamond() -> None: + """Diamond: seed → (a, b) → merge. Path is seed → a/b → merge (length 3).""" + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["X"]}), + LLMTextColumnConfig(name="a", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="b", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="merge", expr="{{ a }} {{ b }}"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "a": GenerationStrategy.CELL_BY_CELL, + "b": GenerationStrategy.CELL_BY_CELL, + "merge": GenerationStrategy.FULL_COLUMN, + } + graph = build_execution_graph(configs, strategies) + path = graph.critical_path() + + assert len(path) == 3 + assert path[0] == "seed" + assert path[-1] == "merge" + + +# -- Task count ------------------------------------------------------------- + + +def test_task_count(simple_graph: ExecutionGraph) -> None: + counts = simple_graph.task_count(num_records=10, buffer_size=3) + + assert counts["topic"] == 4 # ceil(10/3) = 4 row groups + assert counts["question"] == 10 # cell-by-cell + assert counts["answer"] == 10 # cell-by-cell + assert counts["score"] == 4 # full-column + + +def test_task_count_exact_divisor(simple_graph: ExecutionGraph) -> None: + counts = simple_graph.task_count(num_records=9, buffer_size=3) + + assert counts["topic"] == 3 + assert counts["question"] == 9 + + +# -- Cell dependencies ------------------------------------------------------ + + +def test_cell_deps_cell_by_cell_upstream(simple_graph: ExecutionGraph) -> None: + """question depends on topic (full-column); answer depends on question (cell-by-cell).""" + # answer[rg=0, row=2] should depend on question[rg=0, row=2] + deps = simple_graph.cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) + assert deps == [("question", 0, 2)] + + +def test_cell_deps_full_column_upstream(simple_graph: ExecutionGraph) -> None: + """question depends on topic (full-column).""" + deps = simple_graph.cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) + assert deps == [("topic", 0, None)] + + +def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: + """topic has no upstream.""" + deps = simple_graph.cell_dependencies("topic", row_group=0, row_index=None, row_group_size=5) + assert deps == [] + + +def test_cell_deps_full_column_downstream_of_cell_by_cell(simple_graph: ExecutionGraph) -> None: + """score (full-column) depends on answer (cell-by-cell) → needs ALL rows.""" + deps = simple_graph.cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) + assert sorted(deps) == [("answer", 0, 0), ("answer", 0, 1), ("answer", 0, 2)] + + +# -- Mermaid output ---------------------------------------------------------- + + +def test_to_mermaid(simple_graph: ExecutionGraph) -> None: + mermaid = simple_graph.to_mermaid() + + assert "graph TD" in mermaid + assert 'topic["topic [full_column]"]' in mermaid + assert 'question["question [cell_by_cell]"]' in mermaid + assert "topic --> question" in mermaid + assert "question --> answer" in mermaid + assert "answer --> score" in mermaid + + +# -- MultiColumnConfig ------------------------------------------------------- + + +def test_multi_column_config() -> None: + """Multi-column sampler config: all sub-columns share the same strategy.""" + multi = SamplerMultiColumnConfig( + columns=[ + SamplerColumnConfig(name="first_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Alice"]}), + SamplerColumnConfig(name="last_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Smith"]}), + ] + ) + configs = [multi] + strategies = { + "first_name": GenerationStrategy.FULL_COLUMN, + "last_name": GenerationStrategy.FULL_COLUMN, + } + graph = build_execution_graph(configs, strategies) + + assert set(graph.columns) == {"first_name", "last_name"} + assert graph.upstream("first_name") == set() + assert graph.upstream("last_name") == set() + + +def test_multi_column_with_downstream_dependency() -> None: + multi = SamplerMultiColumnConfig( + columns=[ + SamplerColumnConfig(name="first_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Alice"]}), + SamplerColumnConfig(name="last_name", sampler_type=SamplerType.CATEGORY, params={"values": ["Smith"]}), + ] + ) + greeting = LLMTextColumnConfig( + name="greeting", + prompt="Hello {{ first_name }} {{ last_name }}", + model_alias=MODEL_ALIAS, + ) + configs = [multi, greeting] + strategies = { + "first_name": GenerationStrategy.FULL_COLUMN, + "last_name": GenerationStrategy.FULL_COLUMN, + "greeting": GenerationStrategy.CELL_BY_CELL, + } + graph = build_execution_graph(configs, strategies) + + assert graph.upstream("greeting") == {"first_name", "last_name"} + + +# -- Validation column dependency ------------------------------------------- + + +def test_validation_column_dependency() -> None: + configs = [ + LLMCodeColumnConfig( + name="code", + prompt="Write code", + code_lang=CodeLang.PYTHON, + model_alias=MODEL_ALIAS, + ), + ValidationColumnConfig( + name="validation", + target_columns=["code"], + validator_type="code", + validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON), + ), + ] + strategies = { + "code": GenerationStrategy.CELL_BY_CELL, + "validation": GenerationStrategy.FULL_COLUMN, + } + graph = build_execution_graph(configs, strategies) + + assert graph.upstream("validation") == {"code"} + assert graph.downstream("code") == {"validation"} + + +# -- Judge column dependency ------------------------------------------------ + + +def test_judge_column_dependency() -> None: + configs = [ + LLMTextColumnConfig(name="text", prompt="Write something", model_alias=MODEL_ALIAS), + LLMJudgeColumnConfig( + name="judge", + prompt="Judge {{ text }}", + scores=[Score(name="quality", description="Quality", options={0: "Bad", 1: "Good"})], + model_alias=MODEL_ALIAS, + ), + ] + strategies = { + "text": GenerationStrategy.CELL_BY_CELL, + "judge": GenerationStrategy.CELL_BY_CELL, + } + graph = build_execution_graph(configs, strategies) + + assert graph.upstream("judge") == {"text"} diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py new file mode 100644 index 000000000..b18f76b98 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from data_designer.engine.dataset_builders.utils.task_model import Task, TaskResult, TaskTrace + + +def test_task_is_frozen() -> None: + task = Task(column="col_a", row_group=0, row_index=1, task_type="cell") + with pytest.raises(AttributeError): + task.column = "col_b" # type: ignore[misc] + + +def test_task_hashable_and_in_set() -> None: + t1 = Task(column="col_a", row_group=0, row_index=1, task_type="cell") + t2 = Task(column="col_a", row_group=0, row_index=1, task_type="cell") + t3 = Task(column="col_a", row_group=0, row_index=2, task_type="cell") + + assert t1 == t2 + assert t1 != t3 + assert hash(t1) == hash(t2) + + s: set[Task] = {t1, t2, t3} + assert len(s) == 2 + + +def test_task_batch_has_none_row_index() -> None: + task = Task(column="col_a", row_group=0, row_index=None, task_type="batch") + assert task.row_index is None + + +@pytest.mark.parametrize( + "task_type", + ["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"], +) +def test_task_types(task_type: str) -> None: + task = Task(column="col", row_group=0, row_index=0, task_type=task_type) + assert task.task_type == task_type + + +def test_task_result_success() -> None: + task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") + result = TaskResult(task=task, status="success", output={"col_a": "value"}) + assert result.status == "success" + assert result.error is None + assert result.retryable is False + + +def test_task_result_error() -> None: + task = Task(column="col_a", row_group=0, row_index=0, task_type="cell") + exc = ValueError("bad input") + result = TaskResult(task=task, status="error", error=exc, retryable=True) + assert result.status == "error" + assert result.error is exc + assert result.retryable is True + + +def test_task_trace_from_task() -> None: + task = Task(column="col_a", row_group=1, row_index=2, task_type="cell") + trace = TaskTrace.from_task(task) + + assert trace.column == "col_a" + assert trace.row_group == 1 + assert trace.row_index == 2 + assert trace.task_type == "cell" + assert trace.dispatched_at == 0.0 + assert trace.status == "" + + +def test_task_trace_mutable() -> None: + task = Task(column="col_a", row_group=0, row_index=None, task_type="batch") + trace = TaskTrace.from_task(task) + + trace.dispatched_at = 1.0 + trace.slot_acquired_at = 1.5 + trace.completed_at = 2.0 + trace.status = "ok" + + assert trace.dispatched_at == 1.0 + assert trace.completed_at - trace.slot_acquired_at == pytest.approx(0.5) + + +def test_task_equality_differs_by_type() -> None: + t1 = Task(column="col_a", row_group=0, row_index=None, task_type="batch") + t2 = Task(column="col_a", row_group=0, row_index=None, task_type="from_scratch") + assert t1 != t2 diff --git a/plans/346/async-generators-and-task-queue.md b/plans/346/async-generators-and-task-queue.md index a9b6b418e..0ac4e7ba7 100644 --- a/plans/346/async-generators-and-task-queue.md +++ b/plans/346/async-generators-and-task-queue.md @@ -1,7 +1,7 @@ # Plan: Async Generators & Task Queue Builder Created: 2026-02-20 -Status: Planning +Status: In Progress Issue: [#346](https://github.com/NVIDIA-NeMo/DataDesigner/issues/346) @@ -225,7 +225,7 @@ The graph is column-granularity only — no cell-level nodes — so it stays sma O(C²) edges worst-case) regardless of row count and avoids the barrier/checkpoint problems of a cell-level graph. -- [ ] `ExecutionGraph` class: +- [x] `ExecutionGraph` class: - Backing stores: `dict[str, set[str]]` column → upstream columns; `dict[str, GenerationStrategy]` column → generation strategy - `upstream(column: str) -> set[str]` — direct dependencies of a column @@ -238,7 +238,7 @@ graph. full-column columns (including from-scratch generators, which report `FULL_COLUMN`) produce `ceil(num_records / buffer_size)` tasks - `to_mermaid() -> str` — Mermaid diagram string; nodes are annotated with strategy type -- [ ] `build_execution_graph(column_configs, strategies: dict[str, GenerationStrategy]) -> ExecutionGraph` utility: +- [x] `build_execution_graph(column_configs, strategies: dict[str, GenerationStrategy]) -> ExecutionGraph` utility: - Input: the ordered list of `ColumnConfigT` / `MultiColumnConfig`, plus a pre-computed strategy map (available from generators at builder init time via `get_generation_strategy()`) - For each config, read `config.required_columns` → set of upstream column names @@ -247,7 +247,7 @@ graph. - For `MultiColumnConfig`, all sub-columns share the same dependencies - Validate: every required column must resolve to a known producer (including registered side-effect outputs), and the graph must be acyclic -- [ ] Unit tests for graph construction, validation, critical path, task count, and Mermaid output +- [x] Unit tests for graph construction, validation, critical path, task count, and Mermaid output **Files**: new module `engine/dataset_builders/utils/execution_graph.py`, tests @@ -257,7 +257,7 @@ A lightweight structure tracking which (column, row_group, row_index) tuples are done. Row indices are **local** to their row group (0-based within each group), matching the buffer manager's per-row-group addressing. -- [ ] `CompletionTracker` class: +- [x] `CompletionTracker` class: - Internal: `dict[int, dict[str, set[int]]]` mapping row_group → column → set of completed local row indices - `mark_complete(column: str, row_group: int, row_index: int)` / `mark_batch_complete(column: str, row_group: int, row_group_size: int)` - `is_ready(column: str, row_group: int, row_index: int, graph: ExecutionGraph) -> bool` — checks all upstream columns for that (row_group, row_index) @@ -266,8 +266,8 @@ matching the buffer manager's per-row-group addressing. `get_ready_tasks` skips dropped rows, in-flight tasks for dropped rows are ignored on completion - `is_row_group_complete(row_group: int, row_group_size: int, all_columns: list[str]) -> bool` — all non-dropped rows have all columns done; `row_group_size` is the original size, dropped rows (via `drop_row`) are excluded internally - `get_ready_tasks(graph: ExecutionGraph, row_groups, dispatched: set[Task]) -> list[Task]` — yields all currently dispatchable tasks, excluding dropped rows and already-dispatched/in-flight tasks; reads `graph.strategy(column)` to determine task granularity per column -- [ ] No locks needed: all access is from the single asyncio event loop thread -- [ ] Unit tests +- [x] No locks needed: all access is from the single asyncio event loop thread +- [x] Unit tests **Files**: new module `engine/dataset_builders/utils/completion_tracker.py`, tests @@ -275,19 +275,19 @@ matching the buffer manager's per-row-group addressing. Simple dataclass representing a unit of work. -- [ ] `Task` dataclass: +- [x] `Task` dataclass: - `column: str` - `row_group: int` - `row_index: int | None` (None for batch tasks) - `task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"]` -- [ ] `TaskResult` with status, output, error info -- [ ] `TaskTrace` dataclass (only instantiated when tracing is enabled): +- [x] `TaskResult` with status, output, error info +- [x] `TaskTrace` dataclass (only instantiated when tracing is enabled): - `column: str`, `row_group: int`, `row_index: int | None`, `task_type: str` - `dispatched_at: float` — `perf_counter()` when `create_task()` fires - `slot_acquired_at: float` — after execution semaphore acquired - `completed_at: float` — in `finally` block after generator returns - `status: str`, `error: str | None` -- [ ] Hashable so we can track dispatched/pending sets +- [x] Hashable so we can track dispatched/pending sets **Files**: new module `engine/dataset_builders/utils/task_model.py` — must be its own module since `CompletionTracker`, `AsyncTaskScheduler`, and the buffer manager all reference `Task`/`TaskResult`; @@ -432,15 +432,15 @@ Wire the new scheduler into `ColumnWiseDatasetBuilder`. Tests are added incrementally with each PR, not deferred to the end. **PR 1 (foundation) — unit tests**: -- [ ] Execution graph construction, validation, topological order, critical path -- [ ] Execution graph: side-effect output columns resolve correctly (e.g., column +- [x] Execution graph construction, validation, topological order, critical path +- [x] Execution graph: side-effect output columns resolve correctly (e.g., column depending on `summary__trace` maps to a dependency on the `summary` generator) -- [ ] Execution graph: `cell_dependencies` returns correct deps for cell-by-cell, +- [x] Execution graph: `cell_dependencies` returns correct deps for cell-by-cell, full-column, and from-scratch columns -- [ ] Execution graph: `task_count` and `to_mermaid` output -- [ ] Completion tracker: `mark_complete`, `is_complete`, `all_complete` -- [ ] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete` -- [ ] Task model: hashability, equality, TaskResult, TaskTrace +- [x] Execution graph: `task_count` and `to_mermaid` output +- [x] Completion tracker: `mark_complete`, `is_complete`, `all_complete` +- [x] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete` +- [x] Task model: hashability, equality, TaskResult, TaskTrace **PR 2 (generators) — unit tests**: - [ ] Symmetric bridging: sync-only generator can be called via `agenerate` From 969a87f9a068f5a9afb32878507fcac9b43ce737 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 26 Feb 2026 18:37:24 -0300 Subject: [PATCH 02/17] refactor: extract readiness helpers and cache topological order Add `is_ready` and `is_batch_ready` methods to CompletionTracker to simplify `ready_tasks`. Cache topological order in ExecutionGraph since the graph is immutable after construction. Move DatasetBuilderColumnConfigT type alias to multi_column_configs. Fix license header years. --- .../utils/completion_tracker.py | 34 +++++++++++++++---- .../dataset_builders/utils/execution_graph.py | 24 +++++++++---- .../dataset_builders/utils/task_model.py | 2 +- .../utils/test_completion_tracker.py | 2 +- .../utils/test_execution_graph.py | 2 +- .../dataset_builders/utils/test_task_model.py | 2 +- 6 files changed, 49 insertions(+), 17 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index e98dd962f..a9ec858dd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -49,6 +49,31 @@ def all_complete(self, cells: list[tuple[str, int, int | None]]) -> bool: return False return True + def is_ready( + self, + column: str, + row_group: int, + row_index: int, + graph: ExecutionGraph, + row_group_size: int, + ) -> bool: + """Check if all upstream columns are done for this (column, row_group, row_index).""" + deps = graph.cell_dependencies(column, row_group, row_index, row_group_size) + return self.all_complete(deps) + + def is_batch_ready( + self, + column: str, + row_group: int, + row_group_size: int, + graph: ExecutionGraph, + ) -> bool: + """Check if all upstream columns are done for all non-dropped rows in the row group.""" + deps = graph.cell_dependencies(column, row_group, None, row_group_size) + # Dropped rows don't need their upstream cells complete + deps = [(c, rg, ri) for c, rg, ri in deps if ri is None or not self.is_dropped(rg, ri)] + return self.all_complete(deps) + def drop_row(self, row_group: int, row_index: int) -> None: self._dropped[row_group].add(row_index) @@ -95,17 +120,14 @@ def get_ready_tasks( task = Task(column=col, row_group=rg_id, row_index=ri, task_type="cell") if task in dispatched: continue - deps = graph.cell_dependencies(col, rg_id, ri, rg_size) - if self.all_complete(deps): + if self.is_ready(col, rg_id, ri, graph, rg_size): ready.append(task) else: task = Task(column=col, row_group=rg_id, row_index=None, task_type="batch") if task in dispatched: continue - # Check if already complete (batch-level) if col in self._completed.get(rg_id, {}): continue - deps = graph.cell_dependencies(col, rg_id, None, rg_size) - if self.all_complete(deps): + if self.is_batch_ready(col, rg_id, rg_size, graph): ready.append(task) return ready diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index d6e4ccbcc..9dca1d26f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -8,12 +8,12 @@ from dataclasses import dataclass, field from data_designer.config.column_configs import GenerationStrategy -from data_designer.config.column_types import ColumnConfigT -from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.multi_column_configs import ( + DatasetBuilderColumnConfigT, + MultiColumnConfig, +) from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError -DatasetBuilderColumnConfigT = ColumnConfigT | MultiColumnConfig - @dataclass class ExecutionGraph: @@ -29,6 +29,7 @@ class ExecutionGraph: _strategies: dict[str, GenerationStrategy] = field(default_factory=dict) _side_effect_map: dict[str, str] = field(default_factory=dict) _columns: list[str] = field(default_factory=list) + _topological_order_cache: list[str] | None = field(default=None, repr=False) def upstream(self, column: str) -> set[str]: """Direct dependencies of *column*.""" @@ -47,7 +48,14 @@ def columns(self) -> list[str]: return list(self._columns) def topological_order(self) -> list[str]: - """Return a valid topological ordering of columns (Kahn's algorithm).""" + """Return a valid topological ordering of columns (Kahn's algorithm). + + Result is cached after first successful computation since the graph is + immutable after construction. + """ + if self._topological_order_cache is not None: + return list(self._topological_order_cache) + in_degree: dict[str, int] = {col: 0 for col in self._columns} for col, deps in self._upstream.items(): if col in in_degree: @@ -68,7 +76,9 @@ def topological_order(self) -> list[str]: raise DAGCircularDependencyError( f"The execution graph contains cyclic dependencies. Resolved {len(order)}/{len(self._columns)} columns." ) - return order + + self._topological_order_cache = order + return list(order) def critical_path(self) -> list[str]: """Longest dependency chain (by number of columns).""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py index c43c9362c..6929dbfbe 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 3fd904a52..77d836154 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 8a714fb16..921752140 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import pytest diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py index b18f76b98..f9e2f9fc4 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import pytest From 0aa7d54c475e0d010eec51892c882e9d2a5761f8 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 27 Feb 2026 15:37:38 -0300 Subject: [PATCH 03/17] refactor: address PR review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename all_complete → is_all_complete for boolean method convention - Add ColumnName, RowGroup, RowIndex type aliases for readability - Add public mutation API to ExecutionGraph (add_column, add_edge, set_side_effect, resolve_side_effect) and rewrite build_execution_graph to use it instead of private attributes - Change TaskTrace.from_task from @staticmethod to @classmethod --- .../utils/completion_tracker.py | 38 +++++------ .../dataset_builders/utils/execution_graph.py | 67 ++++++++++++------- .../dataset_builders/utils/task_model.py | 24 ++++--- .../utils/test_completion_tracker.py | 12 ++-- 4 files changed, 81 insertions(+), 60 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index a9ec858dd..4d7f4f909 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy -from data_designer.engine.dataset_builders.utils.task_model import Task +from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroup, RowIndex, Task if TYPE_CHECKING: from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -22,20 +22,20 @@ class CompletionTracker: def __init__(self) -> None: # row_group → column → set of completed local row indices - self._completed: dict[int, dict[str, set[int]]] = defaultdict(lambda: defaultdict(set)) + self._completed: dict[RowGroup, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set)) # row_group → set of dropped row indices - self._dropped: dict[int, set[int]] = defaultdict(set) + self._dropped: dict[RowGroup, set[RowIndex]] = defaultdict(set) - def mark_complete(self, column: str, row_group: int, row_index: int) -> None: + def mark_complete(self, column: ColumnName, row_group: RowGroup, row_index: RowIndex) -> None: self._completed[row_group][column].add(row_index) - def mark_batch_complete(self, column: str, row_group: int, row_group_size: int) -> None: + def mark_batch_complete(self, column: ColumnName, row_group: RowGroup, row_group_size: int) -> None: self._completed[row_group][column] = set(range(row_group_size)) - def is_complete(self, column: str, row_group: int, row_index: int) -> bool: + def is_complete(self, column: ColumnName, row_group: RowGroup, row_index: RowIndex) -> bool: return row_index in self._completed.get(row_group, {}).get(column, set()) - def all_complete(self, cells: list[tuple[str, int, int | None]]) -> bool: + def is_all_complete(self, cells: list[tuple[ColumnName, RowGroup, RowIndex | None]]) -> bool: """Check whether all the given (column, row_group, row_index) tuples are done. A ``row_index`` of ``None`` means the entire batch for that column must @@ -51,20 +51,20 @@ def all_complete(self, cells: list[tuple[str, int, int | None]]) -> bool: def is_ready( self, - column: str, - row_group: int, - row_index: int, + column: ColumnName, + row_group: RowGroup, + row_index: RowIndex, graph: ExecutionGraph, row_group_size: int, ) -> bool: """Check if all upstream columns are done for this (column, row_group, row_index).""" deps = graph.cell_dependencies(column, row_group, row_index, row_group_size) - return self.all_complete(deps) + return self.is_all_complete(deps) def is_batch_ready( self, - column: str, - row_group: int, + column: ColumnName, + row_group: RowGroup, row_group_size: int, graph: ExecutionGraph, ) -> bool: @@ -72,19 +72,19 @@ def is_batch_ready( deps = graph.cell_dependencies(column, row_group, None, row_group_size) # Dropped rows don't need their upstream cells complete deps = [(c, rg, ri) for c, rg, ri in deps if ri is None or not self.is_dropped(rg, ri)] - return self.all_complete(deps) + return self.is_all_complete(deps) - def drop_row(self, row_group: int, row_index: int) -> None: + def drop_row(self, row_group: RowGroup, row_index: RowIndex) -> None: self._dropped[row_group].add(row_index) - def is_dropped(self, row_group: int, row_index: int) -> bool: + def is_dropped(self, row_group: RowGroup, row_index: RowIndex) -> bool: return row_index in self._dropped.get(row_group, set()) def is_row_group_complete( self, - row_group: int, + row_group: RowGroup, row_group_size: int, - all_columns: list[str], + all_columns: list[ColumnName], ) -> bool: """All non-dropped rows have all columns done.""" dropped = self._dropped.get(row_group, set()) @@ -100,7 +100,7 @@ def is_row_group_complete( def get_ready_tasks( self, graph: ExecutionGraph, - row_groups: list[tuple[int, int]], + row_groups: list[tuple[RowGroup, int]], dispatched: set[Task], ) -> list[Task]: """Return all currently dispatchable tasks. diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 9dca1d26f..9cffa4f90 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -13,6 +13,7 @@ MultiColumnConfig, ) from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError +from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroup, RowIndex @dataclass @@ -24,26 +25,44 @@ class ExecutionGraph: separately by ``CompletionTracker``. """ - _upstream: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) - _downstream: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set)) - _strategies: dict[str, GenerationStrategy] = field(default_factory=dict) - _side_effect_map: dict[str, str] = field(default_factory=dict) - _columns: list[str] = field(default_factory=list) - _topological_order_cache: list[str] | None = field(default=None, repr=False) + _upstream: dict[ColumnName, set[ColumnName]] = field(default_factory=lambda: defaultdict(set)) + _downstream: dict[ColumnName, set[ColumnName]] = field(default_factory=lambda: defaultdict(set)) + _strategies: dict[ColumnName, GenerationStrategy] = field(default_factory=dict) + _side_effect_map: dict[ColumnName, ColumnName] = field(default_factory=dict) + _columns: list[ColumnName] = field(default_factory=list) + _topological_order_cache: list[ColumnName] | None = field(default=None, repr=False) - def upstream(self, column: str) -> set[str]: + def add_column(self, name: ColumnName, strategy: GenerationStrategy) -> None: + """Register a column and its generation strategy.""" + self._columns.append(name) + self._strategies[name] = strategy + + def add_edge(self, upstream: ColumnName, downstream: ColumnName) -> None: + """Add a dependency edge: *downstream* depends on *upstream*.""" + self._upstream[downstream].add(upstream) + self._downstream[upstream].add(downstream) + + def set_side_effect(self, side_effect_col: ColumnName, producer: ColumnName) -> None: + """Map a side-effect column name to its producing column.""" + self._side_effect_map[side_effect_col] = producer + + def resolve_side_effect(self, column: ColumnName) -> ColumnName: + """Resolve a column name through the side-effect map.""" + return self._side_effect_map.get(column, column) + + def upstream(self, column: ColumnName) -> set[ColumnName]: """Direct dependencies of *column*.""" return self._upstream.get(column, set()) - def downstream(self, column: str) -> set[str]: + def downstream(self, column: ColumnName) -> set[ColumnName]: """Columns that depend on *column*.""" return self._downstream.get(column, set()) - def strategy(self, column: str) -> GenerationStrategy: + def strategy(self, column: ColumnName) -> GenerationStrategy: return self._strategies[column] @property - def columns(self) -> list[str]: + def columns(self) -> list[ColumnName]: """All column names in insertion order.""" return list(self._columns) @@ -101,14 +120,14 @@ def critical_path(self) -> list[str]: path.reverse() return path - def task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: + def task_count(self, num_records: int, buffer_size: int) -> dict[ColumnName, int]: """Exact task count per column before the run starts. Cell-by-cell columns produce ``num_records`` tasks. Full-column columns (including from-scratch) produce ``ceil(num_records / buffer_size)`` tasks. """ num_row_groups = math.ceil(num_records / buffer_size) - counts: dict[str, int] = {} + counts: dict[ColumnName, int] = {} for col in self._columns: strat = self._strategies[col] if strat == GenerationStrategy.CELL_BY_CELL: @@ -119,17 +138,17 @@ def task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: def cell_dependencies( self, - column: str, - row_group: int, - row_index: int | None, + column: ColumnName, + row_group: RowGroup, + row_index: RowIndex | None, row_group_size: int, - ) -> list[tuple[str, int, int | None]]: + ) -> list[tuple[ColumnName, RowGroup, RowIndex | None]]: """Derive cell-level deps on demand from column-level DAG + strategy. Returns a list of ``(upstream_column, row_group, row_index)`` tuples that must be complete before this task can run. """ - deps: list[tuple[str, int, int | None]] = [] + deps: list[tuple[ColumnName, RowGroup, RowIndex | None]] = [] for up_col in self.upstream(column): up_strategy = self._strategies[up_col] if up_strategy == GenerationStrategy.CELL_BY_CELL: @@ -157,7 +176,7 @@ def to_mermaid(self) -> str: def build_execution_graph( column_configs: list[DatasetBuilderColumnConfigT], - strategies: dict[str, GenerationStrategy], + strategies: dict[ColumnName, GenerationStrategy], ) -> ExecutionGraph: """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. @@ -177,13 +196,12 @@ def build_execution_graph( for sub in sub_configs: name = sub.name - graph._columns.append(name) - graph._strategies[name] = strategies[name] + graph.add_column(name, strategies[name]) for se_col in sub.side_effect_columns: - graph._side_effect_map[se_col] = name + graph.set_side_effect(se_col, name) - known_columns = set(graph._columns) + known_columns = set(graph.columns) # Second pass: build edges for config in column_configs: @@ -195,15 +213,14 @@ def build_execution_graph( for sub in sub_configs: name = sub.name for req in sub.required_columns: - resolved = graph._side_effect_map.get(req, req) + resolved = graph.resolve_side_effect(req) if resolved not in known_columns: raise ValueError( f"Column '{name}' requires '{req}' (resolved to '{resolved}') which is not a known producer." ) if resolved == name: continue # skip self-dependency - graph._upstream[name].add(resolved) - graph._downstream[resolved].add(name) + graph.add_edge(upstream=resolved, downstream=name) # Validate acyclicity graph.topological_order() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py index 6929dbfbe..6a1f20126 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -4,16 +4,20 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Literal, TypeAlias + +ColumnName: TypeAlias = str +RowGroup: TypeAlias = int +RowIndex: TypeAlias = int @dataclass(frozen=True) class Task: """A unit of work for the async scheduler.""" - column: str - row_group: int - row_index: int | None # None for batch/full-column tasks + column: ColumnName + row_group: RowGroup + row_index: RowIndex | None # None for batch/full-column tasks task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"] @@ -32,9 +36,9 @@ class TaskResult: class TaskTrace: """Timing trace for a single task. Only created when tracing is enabled.""" - column: str - row_group: int - row_index: int | None + column: ColumnName + row_group: RowGroup + row_index: RowIndex | None task_type: str dispatched_at: float = 0.0 slot_acquired_at: float = 0.0 @@ -42,9 +46,9 @@ class TaskTrace: status: str = "" error: str | None = None - @staticmethod - def from_task(task: Task) -> TaskTrace: - return TaskTrace( + @classmethod + def from_task(cls, task: Task) -> TaskTrace: + return cls( column=task.column, row_group=task.row_group, row_index=task.row_index, diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 77d836154..e442cc4e1 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -55,7 +55,7 @@ def test_mark_batch_complete() -> None: assert not tracker.is_complete("col_a", 0, 3) -# -- all_complete ----------------------------------------------------------- +# -- is_all_complete ----------------------------------------------------------- def test_all_complete_cell_level() -> None: @@ -63,25 +63,25 @@ def test_all_complete_cell_level() -> None: tracker.mark_complete("col_a", 0, 0) tracker.mark_complete("col_a", 0, 1) - assert tracker.all_complete([("col_a", 0, 0), ("col_a", 0, 1)]) - assert not tracker.all_complete([("col_a", 0, 0), ("col_a", 0, 2)]) + assert tracker.is_all_complete([("col_a", 0, 0), ("col_a", 0, 1)]) + assert not tracker.is_all_complete([("col_a", 0, 0), ("col_a", 0, 2)]) def test_all_complete_batch_level() -> None: tracker = CompletionTracker() tracker.mark_batch_complete("col_a", 0, 3) - assert tracker.all_complete([("col_a", 0, None)]) + assert tracker.is_all_complete([("col_a", 0, None)]) def test_all_complete_batch_not_present() -> None: tracker = CompletionTracker() - assert not tracker.all_complete([("col_a", 0, None)]) + assert not tracker.is_all_complete([("col_a", 0, None)]) def test_all_complete_empty_list() -> None: tracker = CompletionTracker() - assert tracker.all_complete([]) + assert tracker.is_all_complete([]) # -- drop_row / is_dropped ------------------------------------------------- From d0d46956e495f1b1a3f572cc2327b761a1e4e09e Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 27 Feb 2026 17:47:00 -0300 Subject: [PATCH 04/17] refactor: address remaining PR review feedback - Rename RowGroup type alias to RowGroupIndex for consistency - Convert ExecutionGraph from dataclass to plain class - Move build_execution_graph logic to ExecutionGraph.create() classmethod --- .../utils/completion_tracker.py | 26 ++-- .../dataset_builders/utils/execution_graph.py | 125 ++++++++++-------- .../dataset_builders/utils/task_model.py | 6 +- 3 files changed, 84 insertions(+), 73 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 4d7f4f909..9bb4001e8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy -from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroup, RowIndex, Task +from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroupIndex, RowIndex, Task if TYPE_CHECKING: from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -22,20 +22,20 @@ class CompletionTracker: def __init__(self) -> None: # row_group → column → set of completed local row indices - self._completed: dict[RowGroup, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set)) + self._completed: dict[RowGroupIndex, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set)) # row_group → set of dropped row indices - self._dropped: dict[RowGroup, set[RowIndex]] = defaultdict(set) + self._dropped: dict[RowGroupIndex, set[RowIndex]] = defaultdict(set) - def mark_complete(self, column: ColumnName, row_group: RowGroup, row_index: RowIndex) -> None: + def mark_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> None: self._completed[row_group][column].add(row_index) - def mark_batch_complete(self, column: ColumnName, row_group: RowGroup, row_group_size: int) -> None: + def mark_batch_complete(self, column: ColumnName, row_group: RowGroupIndex, row_group_size: int) -> None: self._completed[row_group][column] = set(range(row_group_size)) - def is_complete(self, column: ColumnName, row_group: RowGroup, row_index: RowIndex) -> bool: + def is_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> bool: return row_index in self._completed.get(row_group, {}).get(column, set()) - def is_all_complete(self, cells: list[tuple[ColumnName, RowGroup, RowIndex | None]]) -> bool: + def is_all_complete(self, cells: list[tuple[ColumnName, RowGroupIndex, RowIndex | None]]) -> bool: """Check whether all the given (column, row_group, row_index) tuples are done. A ``row_index`` of ``None`` means the entire batch for that column must @@ -52,7 +52,7 @@ def is_all_complete(self, cells: list[tuple[ColumnName, RowGroup, RowIndex | Non def is_ready( self, column: ColumnName, - row_group: RowGroup, + row_group: RowGroupIndex, row_index: RowIndex, graph: ExecutionGraph, row_group_size: int, @@ -64,7 +64,7 @@ def is_ready( def is_batch_ready( self, column: ColumnName, - row_group: RowGroup, + row_group: RowGroupIndex, row_group_size: int, graph: ExecutionGraph, ) -> bool: @@ -74,15 +74,15 @@ def is_batch_ready( deps = [(c, rg, ri) for c, rg, ri in deps if ri is None or not self.is_dropped(rg, ri)] return self.is_all_complete(deps) - def drop_row(self, row_group: RowGroup, row_index: RowIndex) -> None: + def drop_row(self, row_group: RowGroupIndex, row_index: RowIndex) -> None: self._dropped[row_group].add(row_index) - def is_dropped(self, row_group: RowGroup, row_index: RowIndex) -> bool: + def is_dropped(self, row_group: RowGroupIndex, row_index: RowIndex) -> bool: return row_index in self._dropped.get(row_group, set()) def is_row_group_complete( self, - row_group: RowGroup, + row_group: RowGroupIndex, row_group_size: int, all_columns: list[ColumnName], ) -> bool: @@ -100,7 +100,7 @@ def is_row_group_complete( def get_ready_tasks( self, graph: ExecutionGraph, - row_groups: list[tuple[RowGroup, int]], + row_groups: list[tuple[RowGroupIndex, int]], dispatched: set[Task], ) -> list[Task]: """Return all currently dispatchable tasks. diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 9cffa4f90..3ea01d78b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -5,7 +5,6 @@ import math from collections import defaultdict, deque -from dataclasses import dataclass, field from data_designer.config.column_configs import GenerationStrategy from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -13,10 +12,9 @@ MultiColumnConfig, ) from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError -from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroup, RowIndex +from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroupIndex, RowIndex -@dataclass class ExecutionGraph: """Column-level static execution graph built from column configs. @@ -25,12 +23,13 @@ class ExecutionGraph: separately by ``CompletionTracker``. """ - _upstream: dict[ColumnName, set[ColumnName]] = field(default_factory=lambda: defaultdict(set)) - _downstream: dict[ColumnName, set[ColumnName]] = field(default_factory=lambda: defaultdict(set)) - _strategies: dict[ColumnName, GenerationStrategy] = field(default_factory=dict) - _side_effect_map: dict[ColumnName, ColumnName] = field(default_factory=dict) - _columns: list[ColumnName] = field(default_factory=list) - _topological_order_cache: list[ColumnName] | None = field(default=None, repr=False) + def __init__(self) -> None: + self._upstream: dict[ColumnName, set[ColumnName]] = defaultdict(set) + self._downstream: dict[ColumnName, set[ColumnName]] = defaultdict(set) + self._strategies: dict[ColumnName, GenerationStrategy] = {} + self._side_effect_map: dict[ColumnName, ColumnName] = {} + self._columns: list[ColumnName] = [] + self._topological_order_cache: list[ColumnName] | None = None def add_column(self, name: ColumnName, strategy: GenerationStrategy) -> None: """Register a column and its generation strategy.""" @@ -139,16 +138,16 @@ def task_count(self, num_records: int, buffer_size: int) -> dict[ColumnName, int def cell_dependencies( self, column: ColumnName, - row_group: RowGroup, + row_group: RowGroupIndex, row_index: RowIndex | None, row_group_size: int, - ) -> list[tuple[ColumnName, RowGroup, RowIndex | None]]: + ) -> list[tuple[ColumnName, RowGroupIndex, RowIndex | None]]: """Derive cell-level deps on demand from column-level DAG + strategy. Returns a list of ``(upstream_column, row_group, row_index)`` tuples that must be complete before this task can run. """ - deps: list[tuple[ColumnName, RowGroup, RowIndex | None]] = [] + deps: list[tuple[ColumnName, RowGroupIndex, RowIndex | None]] = [] for up_col in self.upstream(column): up_strategy = self._strategies[up_col] if up_strategy == GenerationStrategy.CELL_BY_CELL: @@ -173,6 +172,61 @@ def to_mermaid(self) -> str: lines.append(f" {dep} --> {col}") return "\n".join(lines) + @classmethod + def create( + cls, + column_configs: list[DatasetBuilderColumnConfigT], + strategies: dict[ColumnName, GenerationStrategy], + ) -> ExecutionGraph: + """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. + + Args: + column_configs: Ordered list of ``ColumnConfigT`` or ``MultiColumnConfig``. + strategies: Map of column name → ``GenerationStrategy``, obtained from + each generator's ``get_generation_strategy()``. + """ + graph = cls() + + # First pass: register all columns, strategies, and side-effect mappings + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + graph.add_column(name, strategies[name]) + + for se_col in sub.side_effect_columns: + graph.set_side_effect(se_col, name) + + known_columns = set(graph.columns) + + # Second pass: build edges + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + for req in sub.required_columns: + resolved = graph.resolve_side_effect(req) + if resolved not in known_columns: + raise ValueError( + f"Column '{name}' requires '{req}' (resolved to '{resolved}') which is not a known producer." + ) + if resolved == name: + continue # skip self-dependency + graph.add_edge(upstream=resolved, downstream=name) + + # Validate acyclicity + graph.topological_order() + + return graph + def build_execution_graph( column_configs: list[DatasetBuilderColumnConfigT], @@ -180,49 +234,6 @@ def build_execution_graph( ) -> ExecutionGraph: """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. - Args: - column_configs: Ordered list of ``ColumnConfigT`` or ``MultiColumnConfig``. - strategies: Map of column name → ``GenerationStrategy``, obtained from - each generator's ``get_generation_strategy()``. + .. deprecated:: Use ``ExecutionGraph.create()`` instead. """ - graph = ExecutionGraph() - - # First pass: register all columns, strategies, and side-effect mappings - for config in column_configs: - if isinstance(config, MultiColumnConfig): - sub_configs = config.columns - else: - sub_configs = [config] - - for sub in sub_configs: - name = sub.name - graph.add_column(name, strategies[name]) - - for se_col in sub.side_effect_columns: - graph.set_side_effect(se_col, name) - - known_columns = set(graph.columns) - - # Second pass: build edges - for config in column_configs: - if isinstance(config, MultiColumnConfig): - sub_configs = config.columns - else: - sub_configs = [config] - - for sub in sub_configs: - name = sub.name - for req in sub.required_columns: - resolved = graph.resolve_side_effect(req) - if resolved not in known_columns: - raise ValueError( - f"Column '{name}' requires '{req}' (resolved to '{resolved}') which is not a known producer." - ) - if resolved == name: - continue # skip self-dependency - graph.add_edge(upstream=resolved, downstream=name) - - # Validate acyclicity - graph.topological_order() - - return graph + return ExecutionGraph.create(column_configs, strategies) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py index 6a1f20126..be95c3e43 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -7,7 +7,7 @@ from typing import Any, Literal, TypeAlias ColumnName: TypeAlias = str -RowGroup: TypeAlias = int +RowGroupIndex: TypeAlias = int RowIndex: TypeAlias = int @@ -16,7 +16,7 @@ class Task: """A unit of work for the async scheduler.""" column: ColumnName - row_group: RowGroup + row_group: RowGroupIndex row_index: RowIndex | None # None for batch/full-column tasks task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"] @@ -37,7 +37,7 @@ class TaskTrace: """Timing trace for a single task. Only created when tracing is enabled.""" column: ColumnName - row_group: RowGroup + row_group: RowGroupIndex row_index: RowIndex | None task_type: str dispatched_at: float = 0.0 From c30abb4e57c2907a1110f1e71dd4b28ab1968d4e Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 2 Mar 2026 17:01:26 -0300 Subject: [PATCH 05/17] refactor: event-driven frontier for CompletionTracker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the poll-based get_ready_tasks (O(C × R × G) per tick) with an event-driven frontier maintained on mark_complete/mark_batch_complete/ drop_row. get_ready_tasks now returns O(frontier) instead of scanning all columns × rows × row groups. --- .../utils/completion_tracker.py | 177 ++++++++++++------ .../dataset_builders/utils/execution_graph.py | 35 +++- .../utils/test_completion_tracker.py | 34 ++-- 3 files changed, 162 insertions(+), 84 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 9bb4001e8..2f4df2b75 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -18,19 +18,93 @@ class CompletionTracker: All access is from the single asyncio event loop thread — no locks needed. Row indices are local to their row group (0-based). + + When *graph* and *row_groups* are provided, an event-driven frontier is + maintained so that ``get_ready_tasks`` returns in O(frontier) instead of + scanning all columns × rows × row groups. """ - def __init__(self) -> None: + def __init__( + self, + graph: ExecutionGraph | None = None, + row_groups: list[tuple[RowGroupIndex, int]] | None = None, + ) -> None: # row_group → column → set of completed local row indices self._completed: dict[RowGroupIndex, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set)) # row_group → set of dropped row indices self._dropped: dict[RowGroupIndex, set[RowIndex]] = defaultdict(set) + self._graph = graph + self._row_group_sizes: dict[RowGroupIndex, int] = {} + self._frontier: set[Task] = set() + + if graph is not None and row_groups is not None: + self._row_group_sizes = {rg_id: size for rg_id, size in row_groups} + self._seed_frontier() + + def _seed_frontier(self) -> None: + """Populate the frontier with root tasks (columns with no upstream deps).""" + assert self._graph is not None + for col in self._graph.topological_order(): + if self._graph.upstream(col): + continue + strategy = self._graph.strategy(col) + for rg_id, rg_size in self._row_group_sizes.items(): + if strategy == GenerationStrategy.CELL_BY_CELL: + for ri in range(rg_size): + self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell")) + else: + self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch")) + def mark_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> None: self._completed[row_group][column].add(row_index) + if self._graph is not None: + self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell")) + self._enqueue_downstream(column, row_group, row_index=row_index) def mark_batch_complete(self, column: ColumnName, row_group: RowGroupIndex, row_group_size: int) -> None: self._completed[row_group][column] = set(range(row_group_size)) + if self._graph is not None: + self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch")) + self._enqueue_downstream(column, row_group, row_index=None) + + def _enqueue_downstream(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex | None) -> None: + """Add newly-ready downstream tasks to the frontier.""" + assert self._graph is not None + rg_completed = self._completed.get(row_group, {}) + rg_dropped = self._dropped.get(row_group, set()) + rg_size = self._row_group_sizes[row_group] + + for down in self._graph.downstream(column): + batch_ups, cell_ups = self._graph.upstream_by_strategy(down) + + # All batch upstreams must be present in completed dict + if any(up not in rg_completed for up in batch_ups): + continue + + down_strategy = self._graph.strategy(down) + + if down_strategy == GenerationStrategy.CELL_BY_CELL: + cell_up_completed = [rg_completed.get(up, set()) for up in cell_ups] + if row_index is not None: + # Cell completion: only check the same row + if row_index not in rg_dropped and all(row_index in s for s in cell_up_completed): + task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell") + self._frontier.add(task) + else: + # Batch completion: check all non-dropped, non-complete rows + down_completed = rg_completed.get(down, set()) + for ri in range(rg_size): + if ri in rg_dropped or ri in down_completed: + continue + if all(ri in s for s in cell_up_completed): + task = Task(column=down, row_group=row_group, row_index=ri, task_type="cell") + self._frontier.add(task) + else: + # FULL_COLUMN downstream: ready when all cell upstreams are fully complete + if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped): + task = Task(column=down, row_group=row_group, row_index=None, task_type="batch") + self._frontier.add(task) def is_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> bool: return row_index in self._completed.get(row_group, {}).get(column, set()) @@ -49,33 +123,33 @@ def is_all_complete(self, cells: list[tuple[ColumnName, RowGroupIndex, RowIndex return False return True - def is_ready( - self, - column: ColumnName, - row_group: RowGroupIndex, - row_index: RowIndex, - graph: ExecutionGraph, - row_group_size: int, - ) -> bool: - """Check if all upstream columns are done for this (column, row_group, row_index).""" - deps = graph.cell_dependencies(column, row_group, row_index, row_group_size) - return self.is_all_complete(deps) - - def is_batch_ready( - self, - column: ColumnName, - row_group: RowGroupIndex, - row_group_size: int, - graph: ExecutionGraph, - ) -> bool: - """Check if all upstream columns are done for all non-dropped rows in the row group.""" - deps = graph.cell_dependencies(column, row_group, None, row_group_size) - # Dropped rows don't need their upstream cells complete - deps = [(c, rg, ri) for c, rg, ri in deps if ri is None or not self.is_dropped(rg, ri)] - return self.is_all_complete(deps) - def drop_row(self, row_group: RowGroupIndex, row_index: RowIndex) -> None: self._dropped[row_group].add(row_index) + if self._graph is not None: + # Remove cell tasks for this row from the frontier + for col in self._graph.columns: + self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell")) + # Dropping a row may unblock batch downstream tasks + self._reevaluate_batch_tasks(row_group) + + def _reevaluate_batch_tasks(self, row_group: RowGroupIndex) -> None: + """Check if any batch tasks became ready after a row was dropped.""" + assert self._graph is not None + rg_completed = self._completed.get(row_group, {}) + rg_dropped = self._dropped.get(row_group, set()) + rg_size = self._row_group_sizes[row_group] + + for col in self._graph.topological_order(): + if self._graph.strategy(col) != GenerationStrategy.FULL_COLUMN: + continue + if col in rg_completed: + continue + batch_ups, cell_ups = self._graph.upstream_by_strategy(col) + if any(up not in rg_completed for up in batch_ups): + continue + if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped): + task = Task(column=col, row_group=row_group, row_index=None, task_type="batch") + self._frontier.add(task) def is_dropped(self, row_group: RowGroupIndex, row_index: RowIndex) -> bool: return row_index in self._dropped.get(row_group, set()) @@ -97,37 +171,24 @@ def is_row_group_complete( return False return True - def get_ready_tasks( - self, - graph: ExecutionGraph, - row_groups: list[tuple[RowGroupIndex, int]], - dispatched: set[Task], - ) -> list[Task]: - """Return all currently dispatchable tasks. + def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]: + """Return all currently dispatchable tasks from the frontier. - Excludes dropped rows and already-dispatched/in-flight tasks. + Excludes already-dispatched/in-flight tasks. """ - ready: list[Task] = [] - for rg_id, rg_size in row_groups: - for col in graph.topological_order(): - strategy = graph.strategy(col) - if strategy == GenerationStrategy.CELL_BY_CELL: - for ri in range(rg_size): - if self.is_dropped(rg_id, ri): - continue - if self.is_complete(col, rg_id, ri): - continue - task = Task(column=col, row_group=rg_id, row_index=ri, task_type="cell") - if task in dispatched: - continue - if self.is_ready(col, rg_id, ri, graph, rg_size): - ready.append(task) - else: - task = Task(column=col, row_group=rg_id, row_index=None, task_type="batch") - if task in dispatched: - continue - if col in self._completed.get(rg_id, {}): - continue - if self.is_batch_ready(col, rg_id, rg_size, graph): - ready.append(task) - return ready + return [t for t in self._frontier if t not in dispatched] + + def _are_cell_ups_complete( + self, + cell_ups: list[ColumnName], + rg_completed: dict[ColumnName, set[RowIndex]], + rg_size: int, + rg_dropped: set[RowIndex], + ) -> bool: + """Check all non-dropped rows are complete for each cell-by-cell upstream column.""" + for up in cell_ups: + up_completed = rg_completed.get(up, set()) + for ri in range(rg_size): + if ri not in rg_dropped and ri not in up_completed: + return False + return True diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 3ea01d78b..4667ae155 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -4,7 +4,7 @@ from __future__ import annotations import math -from collections import defaultdict, deque +from collections import deque from data_designer.config.column_configs import GenerationStrategy from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -24,12 +24,13 @@ class ExecutionGraph: """ def __init__(self) -> None: - self._upstream: dict[ColumnName, set[ColumnName]] = defaultdict(set) - self._downstream: dict[ColumnName, set[ColumnName]] = defaultdict(set) + self._upstream: dict[ColumnName, set[ColumnName]] = {} + self._downstream: dict[ColumnName, set[ColumnName]] = {} self._strategies: dict[ColumnName, GenerationStrategy] = {} self._side_effect_map: dict[ColumnName, ColumnName] = {} self._columns: list[ColumnName] = [] self._topological_order_cache: list[ColumnName] | None = None + self._upstream_by_strategy_cache: dict[ColumnName, tuple[list[ColumnName], list[ColumnName]]] = {} def add_column(self, name: ColumnName, strategy: GenerationStrategy) -> None: """Register a column and its generation strategy.""" @@ -38,8 +39,8 @@ def add_column(self, name: ColumnName, strategy: GenerationStrategy) -> None: def add_edge(self, upstream: ColumnName, downstream: ColumnName) -> None: """Add a dependency edge: *downstream* depends on *upstream*.""" - self._upstream[downstream].add(upstream) - self._downstream[upstream].add(downstream) + self._upstream.setdefault(downstream, set()).add(upstream) + self._downstream.setdefault(upstream, set()).add(downstream) def set_side_effect(self, side_effect_col: ColumnName, producer: ColumnName) -> None: """Map a side-effect column name to its producing column.""" @@ -60,10 +61,26 @@ def downstream(self, column: ColumnName) -> set[ColumnName]: def strategy(self, column: ColumnName) -> GenerationStrategy: return self._strategies[column] + def upstream_by_strategy(self, column: ColumnName) -> tuple[list[ColumnName], list[ColumnName]]: + """Split upstream columns into (batch, cell_by_cell) by strategy. Cached.""" + cached = self._upstream_by_strategy_cache.get(column) + if cached is not None: + return cached + batch: list[ColumnName] = [] + cell: list[ColumnName] = [] + for up_col in self.upstream(column): + if self._strategies[up_col] == GenerationStrategy.CELL_BY_CELL: + cell.append(up_col) + else: + batch.append(up_col) + result = (batch, cell) + self._upstream_by_strategy_cache[column] = result + return result + @property def columns(self) -> list[ColumnName]: - """All column names in insertion order.""" - return list(self._columns) + """All column names in insertion order. Do not mutate.""" + return self._columns def topological_order(self) -> list[str]: """Return a valid topological ordering of columns (Kahn's algorithm). @@ -72,7 +89,7 @@ def topological_order(self) -> list[str]: immutable after construction. """ if self._topological_order_cache is not None: - return list(self._topological_order_cache) + return self._topological_order_cache in_degree: dict[str, int] = {col: 0 for col in self._columns} for col, deps in self._upstream.items(): @@ -96,7 +113,7 @@ def topological_order(self) -> list[str]: ) self._topological_order_cache = order - return list(order) + return order def critical_path(self) -> list[str]: """Longest dependency chain (by number of columns).""" diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index e442cc4e1..1d054ab87 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -140,10 +140,10 @@ def test_row_group_not_complete_missing_non_dropped() -> None: def test_get_ready_tasks_seeds_first() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3)]) dispatched: set[Task] = set() - ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready = tracker.get_ready_tasks(dispatched) # Only the seed column should be ready (no upstream) assert len(ready) == 1 @@ -153,12 +153,12 @@ def test_get_ready_tasks_seeds_first() -> None: def test_get_ready_tasks_after_seed_complete() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3)]) dispatched: set[Task] = set() tracker.mark_batch_complete("topic", 0, 3) - ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready = tracker.get_ready_tasks(dispatched) # All question cells should be ready (topic is done) question_tasks = [t for t in ready if t.column == "question"] @@ -169,27 +169,27 @@ def test_get_ready_tasks_after_seed_complete() -> None: def test_get_ready_tasks_skips_dispatched() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3)]) dispatched: set[Task] = set() tracker.mark_batch_complete("topic", 0, 3) - ready1 = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready1 = tracker.get_ready_tasks(dispatched) dispatched.update(ready1) - ready2 = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready2 = tracker.get_ready_tasks(dispatched) assert len(ready2) == 0 def test_get_ready_tasks_skips_dropped_rows() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3)]) dispatched: set[Task] = set() tracker.mark_batch_complete("topic", 0, 3) tracker.drop_row(0, 1) - ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready = tracker.get_ready_tasks(dispatched) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 2 @@ -198,7 +198,7 @@ def test_get_ready_tasks_skips_dropped_rows() -> None: def test_get_ready_tasks_full_column_waits_for_all_cells() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3)]) dispatched: set[Task] = set() tracker.mark_batch_complete("topic", 0, 3) @@ -206,7 +206,7 @@ def test_get_ready_tasks_full_column_waits_for_all_cells() -> None: tracker.mark_complete("question", 0, 1) # question[0,2] not done yet - ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready = tracker.get_ready_tasks(dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 0 # score waits for all question rows @@ -214,14 +214,14 @@ def test_get_ready_tasks_full_column_waits_for_all_cells() -> None: def test_get_ready_tasks_full_column_ready_when_all_cells_done() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3)]) dispatched: set[Task] = set() tracker.mark_batch_complete("topic", 0, 3) for ri in range(3): tracker.mark_complete("question", 0, ri) - ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready = tracker.get_ready_tasks(dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 @@ -230,14 +230,14 @@ def test_get_ready_tasks_full_column_ready_when_all_cells_done() -> None: def test_get_ready_tasks_multiple_row_groups() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3), (1, 2)]) dispatched: set[Task] = set() # Both row groups have topic done tracker.mark_batch_complete("topic", 0, 3) tracker.mark_batch_complete("topic", 1, 2) - ready = tracker.get_ready_tasks(graph, [(0, 3), (1, 2)], dispatched) + ready = tracker.get_ready_tasks(dispatched) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 5 # 3 from rg0 + 2 from rg1 @@ -245,12 +245,12 @@ def test_get_ready_tasks_multiple_row_groups() -> None: def test_get_ready_tasks_skips_already_complete_batch() -> None: graph = _build_simple_graph() - tracker = CompletionTracker() + tracker = CompletionTracker(graph, [(0, 3)]) dispatched: set[Task] = set() tracker.mark_batch_complete("topic", 0, 3) - ready = tracker.get_ready_tasks(graph, [(0, 3)], dispatched) + ready = tracker.get_ready_tasks(dispatched) # topic is already complete, should not be in ready tasks topic_tasks = [t for t in ready if t.column == "topic"] From 638c878b9b4b5c64f4890081c6f3f5a9029ba3a3 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Tue, 3 Mar 2026 13:24:12 -0300 Subject: [PATCH 06/17] refactor: extract ready_ctx fixture in completion tracker tests - Add ReadyTasksFixture dataclass and ready_ctx pytest fixture to deduplicate graph/tracker/dispatched setup across get_ready_tasks tests - Align test with ExecutionGraph.create API rename - Remove redundant inline comments --- .../utils/test_completion_tracker.py | 111 ++++++++---------- 1 file changed, 49 insertions(+), 62 deletions(-) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 1d054ab87..7e9a9e302 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -2,7 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 -# Helpers to build minimal graphs without real column configs +from dataclasses import dataclass + +import pytest + from data_designer.config.column_configs import ( ExpressionColumnConfig, GenerationStrategy, @@ -11,7 +14,7 @@ ) from data_designer.config.sampler_params import SamplerType from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker -from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph, build_execution_graph +from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.task_model import Task MODEL_ALIAS = "stub" @@ -29,7 +32,23 @@ def _build_simple_graph() -> ExecutionGraph: "question": GenerationStrategy.CELL_BY_CELL, "score": GenerationStrategy.FULL_COLUMN, } - return build_execution_graph(configs, strategies) + return ExecutionGraph.create(configs, strategies) + + +@dataclass +class ReadyTasksFixture: + tracker: CompletionTracker + dispatched: set[Task] + + +@pytest.fixture() +def ready_ctx() -> ReadyTasksFixture: + """CompletionTracker wired to the simple 3-column graph with one row group of size 3.""" + graph = _build_simple_graph() + return ReadyTasksFixture( + tracker=CompletionTracker(graph, [(0, 3)]), + dispatched=set(), + ) # -- mark_complete / is_complete ------------------------------------------- @@ -138,90 +157,64 @@ def test_row_group_not_complete_missing_non_dropped() -> None: # -- get_ready_tasks -------------------------------------------------------- -def test_get_ready_tasks_seeds_first() -> None: - graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3)]) - dispatched: set[Task] = set() - - ready = tracker.get_ready_tasks(dispatched) +def test_get_ready_tasks_seeds_first(ready_ctx: ReadyTasksFixture) -> None: + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) - # Only the seed column should be ready (no upstream) assert len(ready) == 1 assert ready[0].column == "topic" assert ready[0].task_type == "batch" -def test_get_ready_tasks_after_seed_complete() -> None: - graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3)]) - dispatched: set[Task] = set() +def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_batch_complete("topic", 0, 3) - tracker.mark_batch_complete("topic", 0, 3) - - ready = tracker.get_ready_tasks(dispatched) + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) - # All question cells should be ready (topic is done) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 3 assert all(t.task_type == "cell" for t in question_tasks) assert {t.row_index for t in question_tasks} == {0, 1, 2} -def test_get_ready_tasks_skips_dispatched() -> None: - graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3)]) - dispatched: set[Task] = set() +def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_batch_complete("topic", 0, 3) - tracker.mark_batch_complete("topic", 0, 3) + ready1 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + ready_ctx.dispatched.update(ready1) - ready1 = tracker.get_ready_tasks(dispatched) - dispatched.update(ready1) - - ready2 = tracker.get_ready_tasks(dispatched) + ready2 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) assert len(ready2) == 0 -def test_get_ready_tasks_skips_dropped_rows() -> None: - graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3)]) - dispatched: set[Task] = set() +def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_batch_complete("topic", 0, 3) + ready_ctx.tracker.drop_row(0, 1) - tracker.mark_batch_complete("topic", 0, 3) - tracker.drop_row(0, 1) - - ready = tracker.get_ready_tasks(dispatched) + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) question_tasks = [t for t in ready if t.column == "question"] assert len(question_tasks) == 2 assert {t.row_index for t in question_tasks} == {0, 2} -def test_get_ready_tasks_full_column_waits_for_all_cells() -> None: - graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3)]) - dispatched: set[Task] = set() - - tracker.mark_batch_complete("topic", 0, 3) - tracker.mark_complete("question", 0, 0) - tracker.mark_complete("question", 0, 1) +def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_batch_complete("topic", 0, 3) + ready_ctx.tracker.mark_complete("question", 0, 0) + ready_ctx.tracker.mark_complete("question", 0, 1) # question[0,2] not done yet - ready = tracker.get_ready_tasks(dispatched) + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) score_tasks = [t for t in ready if t.column == "score"] - assert len(score_tasks) == 0 # score waits for all question rows + assert len(score_tasks) == 0 -def test_get_ready_tasks_full_column_ready_when_all_cells_done() -> None: - graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3)]) - dispatched: set[Task] = set() - - tracker.mark_batch_complete("topic", 0, 3) +def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_batch_complete("topic", 0, 3) for ri in range(3): - tracker.mark_complete("question", 0, ri) + ready_ctx.tracker.mark_complete("question", 0, ri) - ready = tracker.get_ready_tasks(dispatched) + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 1 @@ -233,7 +226,6 @@ def test_get_ready_tasks_multiple_row_groups() -> None: tracker = CompletionTracker(graph, [(0, 3), (1, 2)]) dispatched: set[Task] = set() - # Both row groups have topic done tracker.mark_batch_complete("topic", 0, 3) tracker.mark_batch_complete("topic", 1, 2) @@ -243,15 +235,10 @@ def test_get_ready_tasks_multiple_row_groups() -> None: assert len(question_tasks) == 5 # 3 from rg0 + 2 from rg1 -def test_get_ready_tasks_skips_already_complete_batch() -> None: - graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3)]) - dispatched: set[Task] = set() +def test_get_ready_tasks_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_batch_complete("topic", 0, 3) - tracker.mark_batch_complete("topic", 0, 3) - - ready = tracker.get_ready_tasks(dispatched) + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) - # topic is already complete, should not be in ready tasks topic_tasks = [t for t in ready if t.column == "topic"] assert len(topic_tasks) == 0 From b08cb3d5f5330cfc6f59c94a141caf85c11fb800 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Tue, 3 Mar 2026 15:26:02 -0300 Subject: [PATCH 07/17] fix: validate tracker args and resolve side-effect name collisions - CompletionTracker now raises ValueError when graph/row_groups are provided without each other - resolve_side_effect prefers real columns over aliases when a name collision exists --- .../utils/completion_tracker.py | 3 +++ .../dataset_builders/utils/execution_graph.py | 8 ++++++- .../utils/test_completion_tracker.py | 11 +++++++++ .../utils/test_execution_graph.py | 23 +++++++++++++++++++ 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 2f4df2b75..b47507879 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -29,6 +29,9 @@ def __init__( graph: ExecutionGraph | None = None, row_groups: list[tuple[RowGroupIndex, int]] | None = None, ) -> None: + if (graph is None) != (row_groups is None): + raise ValueError("`graph` and `row_groups` must be provided together.") + # row_group → column → set of completed local row indices self._completed: dict[RowGroupIndex, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set)) # row_group → set of dropped row indices diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 4667ae155..241e3676b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -47,7 +47,13 @@ def set_side_effect(self, side_effect_col: ColumnName, producer: ColumnName) -> self._side_effect_map[side_effect_col] = producer def resolve_side_effect(self, column: ColumnName) -> ColumnName: - """Resolve a column name through the side-effect map.""" + """Resolve a column name through the side-effect map. + + If a real column exists with the same name as a side-effect alias, + the real column wins. + """ + if column in self._strategies: + return column return self._side_effect_map.get(column, column) def upstream(self, column: ColumnName) -> set[ColumnName]: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 7e9a9e302..c780a5532 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -51,6 +51,17 @@ def ready_ctx() -> ReadyTasksFixture: ) +def test_tracker_requires_row_groups_with_graph() -> None: + graph = _build_simple_graph() + with pytest.raises(ValueError, match="provided together"): + CompletionTracker(graph=graph, row_groups=None) + + +def test_tracker_requires_graph_with_row_groups() -> None: + with pytest.raises(ValueError, match="provided together"): + CompletionTracker(graph=None, row_groups=[(0, 3)]) + + # -- mark_complete / is_complete ------------------------------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 921752140..454397930 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -133,6 +133,29 @@ def test_reasoning_content_side_effect() -> None: assert graph.upstream("reasoning") == {"answer"} +def test_side_effect_name_collision_prefers_real_column() -> None: + configs = [ + LLMTextColumnConfig( + name="summary", + prompt="Summarize", + model_alias=MODEL_ALIAS, + with_trace="last_message", + ), + SamplerColumnConfig(name="summary__trace", sampler_type=SamplerType.CATEGORY, params={"values": ["OVERRIDE"]}), + ExpressionColumnConfig(name="trace_len", expr="{{ summary__trace }}"), + ] + strategies = { + "summary": GenerationStrategy.CELL_BY_CELL, + "summary__trace": GenerationStrategy.FULL_COLUMN, + "trace_len": GenerationStrategy.FULL_COLUMN, + } + graph = build_execution_graph(configs, strategies) + + assert graph.upstream("trace_len") == {"summary__trace"} + assert graph.downstream("summary__trace") == {"trace_len"} + assert graph.downstream("summary") == set() + + # -- Validation tests ------------------------------------------------------- From ebb94285991e65e3b580211443f6d69d92468ae9 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 4 Mar 2026 19:37:28 -0300 Subject: [PATCH 08/17] =?UTF-8?q?refactor:=20address=20PR=20review=20feedb?= =?UTF-8?q?ack=20=E2=80=94=20naming,=20CellRef,=20batch=20semantics?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix critical_path() crash on empty graph (early return) - Fix is_all_complete batch semantics via _batch_complete tracking set - Add row-group size mismatch validation in mark_row_range_complete - Add unknown row_group validation in mark_cell_complete - Rename methods for verb-prefix convention: upstream → get_upstream_columns, downstream → get_downstream_columns, critical_path → get_longest_dependency_chain, mark_complete → mark_cell_complete, mark_batch_complete → mark_row_range_complete - Introduce CellRef NamedTuple, remove ColumnName/RowGroupIndex/RowIndex aliases - Delete deprecated build_execution_graph() wrapper - Return defensive copy from topological_order() - Add regression tests for fixed bugs --- .../utils/completion_tracker.py | 65 +++++++----- .../dataset_builders/utils/execution_graph.py | 88 ++++++++--------- .../dataset_builders/utils/task_model.py | 24 +++-- .../utils/test_completion_tracker.py | 82 +++++++++------ .../utils/test_execution_graph.py | 99 ++++++++++--------- 5 files changed, 195 insertions(+), 163 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index b47507879..b4e90c346 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -7,14 +7,14 @@ from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy -from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroupIndex, RowIndex, Task +from data_designer.engine.dataset_builders.utils.task_model import CellRef, Task if TYPE_CHECKING: from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph class CompletionTracker: - """Tracks which (column, row_group, row_index) tuples are done. + """Tracks which cells (column, row_group, row_index) are done. All access is from the single asyncio event loop thread — no locks needed. Row indices are local to their row group (0-based). @@ -27,18 +27,19 @@ class CompletionTracker: def __init__( self, graph: ExecutionGraph | None = None, - row_groups: list[tuple[RowGroupIndex, int]] | None = None, + row_groups: list[tuple[int, int]] | None = None, ) -> None: if (graph is None) != (row_groups is None): raise ValueError("`graph` and `row_groups` must be provided together.") # row_group → column → set of completed local row indices - self._completed: dict[RowGroupIndex, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set)) + self._completed: dict[int, dict[str, set[int]]] = defaultdict(lambda: defaultdict(set)) # row_group → set of dropped row indices - self._dropped: dict[RowGroupIndex, set[RowIndex]] = defaultdict(set) + self._dropped: dict[int, set[int]] = defaultdict(set) self._graph = graph - self._row_group_sizes: dict[RowGroupIndex, int] = {} + self._row_group_sizes: dict[int, int] = {} + self._batch_complete: dict[int, set[str]] = defaultdict(set) self._frontier: set[Task] = set() if graph is not None and row_groups is not None: @@ -49,7 +50,7 @@ def _seed_frontier(self) -> None: """Populate the frontier with root tasks (columns with no upstream deps).""" assert self._graph is not None for col in self._graph.topological_order(): - if self._graph.upstream(col): + if self._graph.get_upstream_columns(col): continue strategy = self._graph.strategy(col) for rg_id, rg_size in self._row_group_sizes.items(): @@ -59,26 +60,31 @@ def _seed_frontier(self) -> None: else: self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch")) - def mark_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> None: + def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None: + self._validate_row_group(row_group) self._completed[row_group][column].add(row_index) if self._graph is not None: self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell")) self._enqueue_downstream(column, row_group, row_index=row_index) - def mark_batch_complete(self, column: ColumnName, row_group: RowGroupIndex, row_group_size: int) -> None: + def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> None: + expected = self._validate_row_group(row_group) + if expected is not None and row_group_size != expected: + raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}") self._completed[row_group][column] = set(range(row_group_size)) + self._batch_complete[row_group].add(column) if self._graph is not None: self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch")) self._enqueue_downstream(column, row_group, row_index=None) - def _enqueue_downstream(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex | None) -> None: + def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> None: """Add newly-ready downstream tasks to the frontier.""" assert self._graph is not None rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) rg_size = self._row_group_sizes[row_group] - for down in self._graph.downstream(column): + for down in self._graph.get_downstream_columns(column): batch_ups, cell_ups = self._graph.upstream_by_strategy(down) # All batch upstreams must be present in completed dict @@ -109,24 +115,25 @@ def _enqueue_downstream(self, column: ColumnName, row_group: RowGroupIndex, row_ task = Task(column=down, row_group=row_group, row_index=None, task_type="batch") self._frontier.add(task) - def is_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> bool: + def is_complete(self, column: str, row_group: int, row_index: int) -> bool: return row_index in self._completed.get(row_group, {}).get(column, set()) - def is_all_complete(self, cells: list[tuple[ColumnName, RowGroupIndex, RowIndex | None]]) -> bool: - """Check whether all the given (column, row_group, row_index) tuples are done. + def is_all_complete(self, cells: list[CellRef]) -> bool: + """Check whether all the given cells are done. A ``row_index`` of ``None`` means the entire batch for that column must - be complete (i.e., that column key must exist in the row group's dict). + have been completed via ``mark_row_range_complete``. """ for col, rg, ri in cells: if ri is None: - if col not in self._completed.get(rg, {}): + if col not in self._batch_complete.get(rg, set()): return False elif not self.is_complete(col, rg, ri): return False return True - def drop_row(self, row_group: RowGroupIndex, row_index: RowIndex) -> None: + def drop_row(self, row_group: int, row_index: int) -> None: + self._validate_row_group(row_group) self._dropped[row_group].add(row_index) if self._graph is not None: # Remove cell tasks for this row from the frontier @@ -135,7 +142,7 @@ def drop_row(self, row_group: RowGroupIndex, row_index: RowIndex) -> None: # Dropping a row may unblock batch downstream tasks self._reevaluate_batch_tasks(row_group) - def _reevaluate_batch_tasks(self, row_group: RowGroupIndex) -> None: + def _reevaluate_batch_tasks(self, row_group: int) -> None: """Check if any batch tasks became ready after a row was dropped.""" assert self._graph is not None rg_completed = self._completed.get(row_group, {}) @@ -154,14 +161,14 @@ def _reevaluate_batch_tasks(self, row_group: RowGroupIndex) -> None: task = Task(column=col, row_group=row_group, row_index=None, task_type="batch") self._frontier.add(task) - def is_dropped(self, row_group: RowGroupIndex, row_index: RowIndex) -> bool: + def is_dropped(self, row_group: int, row_index: int) -> bool: return row_index in self._dropped.get(row_group, set()) def is_row_group_complete( self, - row_group: RowGroupIndex, + row_group: int, row_group_size: int, - all_columns: list[ColumnName], + all_columns: list[str], ) -> bool: """All non-dropped rows have all columns done.""" dropped = self._dropped.get(row_group, set()) @@ -183,10 +190,10 @@ def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]: def _are_cell_ups_complete( self, - cell_ups: list[ColumnName], - rg_completed: dict[ColumnName, set[RowIndex]], + cell_ups: list[str], + rg_completed: dict[str, set[int]], rg_size: int, - rg_dropped: set[RowIndex], + rg_dropped: set[int], ) -> bool: """Check all non-dropped rows are complete for each cell-by-cell upstream column.""" for up in cell_ups: @@ -195,3 +202,13 @@ def _are_cell_ups_complete( if ri not in rg_dropped and ri not in up_completed: return False return True + + def _validate_row_group(self, row_group: int) -> int | None: + """Validate row-group id in graph-enabled mode and return its expected size.""" + if self._graph is None: + return None + expected = self._row_group_sizes.get(row_group) + if expected is None: + known = sorted(self._row_group_sizes) + raise ValueError(f"Unknown row_group {row_group}. Known row_groups: {known}") + return expected diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 241e3676b..684621541 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -12,7 +12,7 @@ MultiColumnConfig, ) from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError -from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroupIndex, RowIndex +from data_designer.engine.dataset_builders.utils.task_model import CellRef class ExecutionGraph: @@ -24,29 +24,29 @@ class ExecutionGraph: """ def __init__(self) -> None: - self._upstream: dict[ColumnName, set[ColumnName]] = {} - self._downstream: dict[ColumnName, set[ColumnName]] = {} - self._strategies: dict[ColumnName, GenerationStrategy] = {} - self._side_effect_map: dict[ColumnName, ColumnName] = {} - self._columns: list[ColumnName] = [] - self._topological_order_cache: list[ColumnName] | None = None - self._upstream_by_strategy_cache: dict[ColumnName, tuple[list[ColumnName], list[ColumnName]]] = {} - - def add_column(self, name: ColumnName, strategy: GenerationStrategy) -> None: + self._upstream: dict[str, set[str]] = {} + self._downstream: dict[str, set[str]] = {} + self._strategies: dict[str, GenerationStrategy] = {} + self._side_effect_map: dict[str, str] = {} + self._columns: list[str] = [] + self._topological_order_cache: list[str] | None = None + self._upstream_by_strategy_cache: dict[str, tuple[list[str], list[str]]] = {} + + def add_column(self, name: str, strategy: GenerationStrategy) -> None: """Register a column and its generation strategy.""" self._columns.append(name) self._strategies[name] = strategy - def add_edge(self, upstream: ColumnName, downstream: ColumnName) -> None: + def add_edge(self, upstream: str, downstream: str) -> None: """Add a dependency edge: *downstream* depends on *upstream*.""" self._upstream.setdefault(downstream, set()).add(upstream) self._downstream.setdefault(upstream, set()).add(downstream) - def set_side_effect(self, side_effect_col: ColumnName, producer: ColumnName) -> None: + def set_side_effect(self, side_effect_col: str, producer: str) -> None: """Map a side-effect column name to its producing column.""" self._side_effect_map[side_effect_col] = producer - def resolve_side_effect(self, column: ColumnName) -> ColumnName: + def resolve_side_effect(self, column: str) -> str: """Resolve a column name through the side-effect map. If a real column exists with the same name as a side-effect alias, @@ -56,25 +56,25 @@ def resolve_side_effect(self, column: ColumnName) -> ColumnName: return column return self._side_effect_map.get(column, column) - def upstream(self, column: ColumnName) -> set[ColumnName]: + def get_upstream_columns(self, column: str) -> set[str]: """Direct dependencies of *column*.""" return self._upstream.get(column, set()) - def downstream(self, column: ColumnName) -> set[ColumnName]: + def get_downstream_columns(self, column: str) -> set[str]: """Columns that depend on *column*.""" return self._downstream.get(column, set()) - def strategy(self, column: ColumnName) -> GenerationStrategy: + def strategy(self, column: str) -> GenerationStrategy: return self._strategies[column] - def upstream_by_strategy(self, column: ColumnName) -> tuple[list[ColumnName], list[ColumnName]]: + def upstream_by_strategy(self, column: str) -> tuple[list[str], list[str]]: """Split upstream columns into (batch, cell_by_cell) by strategy. Cached.""" cached = self._upstream_by_strategy_cache.get(column) if cached is not None: return cached - batch: list[ColumnName] = [] - cell: list[ColumnName] = [] - for up_col in self.upstream(column): + batch: list[str] = [] + cell: list[str] = [] + for up_col in self.get_upstream_columns(column): if self._strategies[up_col] == GenerationStrategy.CELL_BY_CELL: cell.append(up_col) else: @@ -84,7 +84,7 @@ def upstream_by_strategy(self, column: ColumnName) -> tuple[list[ColumnName], li return result @property - def columns(self) -> list[ColumnName]: + def columns(self) -> list[str]: """All column names in insertion order. Do not mutate.""" return self._columns @@ -95,7 +95,7 @@ def topological_order(self) -> list[str]: immutable after construction. """ if self._topological_order_cache is not None: - return self._topological_order_cache + return list(self._topological_order_cache) in_degree: dict[str, int] = {col: 0 for col in self._columns} for col, deps in self._upstream.items(): @@ -119,11 +119,13 @@ def topological_order(self) -> list[str]: ) self._topological_order_cache = order - return order + return list(order) - def critical_path(self) -> list[str]: + def get_longest_dependency_chain(self) -> list[str]: """Longest dependency chain (by number of columns).""" order = self.topological_order() + if not order: + return [] dist: dict[str, int] = {col: 0 for col in order} pred: dict[str, str | None] = {col: None for col in order} @@ -142,14 +144,14 @@ def critical_path(self) -> list[str]: path.reverse() return path - def task_count(self, num_records: int, buffer_size: int) -> dict[ColumnName, int]: + def task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: """Exact task count per column before the run starts. Cell-by-cell columns produce ``num_records`` tasks. Full-column columns (including from-scratch) produce ``ceil(num_records / buffer_size)`` tasks. """ num_row_groups = math.ceil(num_records / buffer_size) - counts: dict[ColumnName, int] = {} + counts: dict[str, int] = {} for col in self._columns: strat = self._strategies[col] if strat == GenerationStrategy.CELL_BY_CELL: @@ -160,27 +162,26 @@ def task_count(self, num_records: int, buffer_size: int) -> dict[ColumnName, int def cell_dependencies( self, - column: ColumnName, - row_group: RowGroupIndex, - row_index: RowIndex | None, + column: str, + row_group: int, + row_index: int | None, row_group_size: int, - ) -> list[tuple[ColumnName, RowGroupIndex, RowIndex | None]]: + ) -> list[CellRef]: """Derive cell-level deps on demand from column-level DAG + strategy. - Returns a list of ``(upstream_column, row_group, row_index)`` tuples - that must be complete before this task can run. + Returns a list of ``CellRef`` that must be complete before this task can run. """ - deps: list[tuple[ColumnName, RowGroupIndex, RowIndex | None]] = [] - for up_col in self.upstream(column): + deps: list[CellRef] = [] + for up_col in self.get_upstream_columns(column): up_strategy = self._strategies[up_col] if up_strategy == GenerationStrategy.CELL_BY_CELL: if row_index is not None: - deps.append((up_col, row_group, row_index)) + deps.append(CellRef(up_col, row_group, row_index)) else: for ri in range(row_group_size): - deps.append((up_col, row_group, ri)) + deps.append(CellRef(up_col, row_group, ri)) else: - deps.append((up_col, row_group, None)) + deps.append(CellRef(up_col, row_group, None)) return deps def to_mermaid(self) -> str: @@ -199,7 +200,7 @@ def to_mermaid(self) -> str: def create( cls, column_configs: list[DatasetBuilderColumnConfigT], - strategies: dict[ColumnName, GenerationStrategy], + strategies: dict[str, GenerationStrategy], ) -> ExecutionGraph: """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. @@ -249,14 +250,3 @@ def create( graph.topological_order() return graph - - -def build_execution_graph( - column_configs: list[DatasetBuilderColumnConfigT], - strategies: dict[ColumnName, GenerationStrategy], -) -> ExecutionGraph: - """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. - - .. deprecated:: Use ``ExecutionGraph.create()`` instead. - """ - return ExecutionGraph.create(column_configs, strategies) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py index be95c3e43..fca2964c5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -4,20 +4,24 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal, TypeAlias +from typing import Any, Literal, NamedTuple -ColumnName: TypeAlias = str -RowGroupIndex: TypeAlias = int -RowIndex: TypeAlias = int + +class CellRef(NamedTuple): + """Reference to a cell (or batch when row_index is None) in the dataset grid.""" + + column: str + row_group: int + row_index: int | None = None @dataclass(frozen=True) class Task: """A unit of work for the async scheduler.""" - column: ColumnName - row_group: RowGroupIndex - row_index: RowIndex | None # None for batch/full-column tasks + column: str + row_group: int + row_index: int | None # None for batch/full-column tasks task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"] @@ -36,9 +40,9 @@ class TaskResult: class TaskTrace: """Timing trace for a single task. Only created when tracing is enabled.""" - column: ColumnName - row_group: RowGroupIndex - row_index: RowIndex | None + column: str + row_group: int + row_index: int | None task_type: str dispatched_at: float = 0.0 slot_acquired_at: float = 0.0 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index c780a5532..5016d9258 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -15,7 +15,7 @@ from data_designer.config.sampler_params import SamplerType from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.engine.dataset_builders.utils.task_model import Task +from data_designer.engine.dataset_builders.utils.task_model import CellRef, Task MODEL_ALIAS = "stub" @@ -62,12 +62,12 @@ def test_tracker_requires_graph_with_row_groups() -> None: CompletionTracker(graph=None, row_groups=[(0, 3)]) -# -- mark_complete / is_complete ------------------------------------------- +# -- mark_cell_complete / is_complete -------------------------------------- def test_mark_and_check_complete() -> None: tracker = CompletionTracker() - tracker.mark_complete("col_a", row_group=0, row_index=0) + tracker.mark_cell_complete("col_a", row_group=0, row_index=0) assert tracker.is_complete("col_a", 0, 0) assert not tracker.is_complete("col_a", 0, 1) @@ -75,9 +75,9 @@ def test_mark_and_check_complete() -> None: assert not tracker.is_complete("col_b", 0, 0) -def test_mark_batch_complete() -> None: +def test_mark_row_range_complete() -> None: tracker = CompletionTracker() - tracker.mark_batch_complete("col_a", row_group=0, row_group_size=3) + tracker.mark_row_range_complete("col_a", row_group=0, row_group_size=3) assert tracker.is_complete("col_a", 0, 0) assert tracker.is_complete("col_a", 0, 1) @@ -85,28 +85,46 @@ def test_mark_batch_complete() -> None: assert not tracker.is_complete("col_a", 0, 3) +def test_mark_row_range_complete_raises_on_size_mismatch(ready_ctx: ReadyTasksFixture) -> None: + with pytest.raises(ValueError, match="Row-group size mismatch"): + ready_ctx.tracker.mark_row_range_complete("topic", row_group=0, row_group_size=2) + + +def test_mark_cell_complete_raises_on_unknown_row_group(ready_ctx: ReadyTasksFixture) -> None: + with pytest.raises(ValueError, match="Unknown row_group"): + ready_ctx.tracker.mark_cell_complete("topic", row_group=999, row_index=0) + + # -- is_all_complete ----------------------------------------------------------- def test_all_complete_cell_level() -> None: tracker = CompletionTracker() - tracker.mark_complete("col_a", 0, 0) - tracker.mark_complete("col_a", 0, 1) + tracker.mark_cell_complete("col_a", 0, 0) + tracker.mark_cell_complete("col_a", 0, 1) - assert tracker.is_all_complete([("col_a", 0, 0), ("col_a", 0, 1)]) - assert not tracker.is_all_complete([("col_a", 0, 0), ("col_a", 0, 2)]) + assert tracker.is_all_complete([CellRef("col_a", 0, 0), CellRef("col_a", 0, 1)]) + assert not tracker.is_all_complete([CellRef("col_a", 0, 0), CellRef("col_a", 0, 2)]) def test_all_complete_batch_level() -> None: tracker = CompletionTracker() - tracker.mark_batch_complete("col_a", 0, 3) + tracker.mark_row_range_complete("col_a", 0, 3) + + assert tracker.is_all_complete([CellRef("col_a", 0, None)]) + + +def test_all_complete_batch_single_cell_not_sufficient() -> None: + """mark_cell_complete on one row must NOT make is_all_complete return True for batch check.""" + tracker = CompletionTracker() + tracker.mark_cell_complete("col_a", 0, 0) - assert tracker.is_all_complete([("col_a", 0, None)]) + assert not tracker.is_all_complete([CellRef("col_a", 0, None)]) def test_all_complete_batch_not_present() -> None: tracker = CompletionTracker() - assert not tracker.is_all_complete([("col_a", 0, None)]) + assert not tracker.is_all_complete([CellRef("col_a", 0, None)]) def test_all_complete_empty_list() -> None: @@ -131,25 +149,25 @@ def test_drop_row() -> None: def test_row_group_complete() -> None: tracker = CompletionTracker() - tracker.mark_batch_complete("col_a", 0, 3) - tracker.mark_batch_complete("col_b", 0, 3) + tracker.mark_row_range_complete("col_a", 0, 3) + tracker.mark_row_range_complete("col_b", 0, 3) assert tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) def test_row_group_incomplete() -> None: tracker = CompletionTracker() - tracker.mark_batch_complete("col_a", 0, 3) + tracker.mark_row_range_complete("col_a", 0, 3) assert not tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) def test_row_group_complete_with_dropped_rows() -> None: tracker = CompletionTracker() - tracker.mark_complete("col_a", 0, 0) - tracker.mark_complete("col_a", 0, 2) - tracker.mark_complete("col_b", 0, 0) - tracker.mark_complete("col_b", 0, 2) + tracker.mark_cell_complete("col_a", 0, 0) + tracker.mark_cell_complete("col_a", 0, 2) + tracker.mark_cell_complete("col_b", 0, 0) + tracker.mark_cell_complete("col_b", 0, 2) tracker.drop_row(0, 1) # row 1 is dropped assert tracker.is_row_group_complete(0, 3, ["col_a", "col_b"]) @@ -157,8 +175,8 @@ def test_row_group_complete_with_dropped_rows() -> None: def test_row_group_not_complete_missing_non_dropped() -> None: tracker = CompletionTracker() - tracker.mark_complete("col_a", 0, 0) - tracker.mark_complete("col_b", 0, 0) + tracker.mark_cell_complete("col_a", 0, 0) + tracker.mark_cell_complete("col_b", 0, 0) tracker.drop_row(0, 1) # row 2 is not dropped and not complete @@ -177,7 +195,7 @@ def test_get_ready_tasks_seeds_first(ready_ctx: ReadyTasksFixture) -> None: def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_batch_complete("topic", 0, 3) + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) @@ -188,7 +206,7 @@ def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> No def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_batch_complete("topic", 0, 3) + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready1 = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) ready_ctx.dispatched.update(ready1) @@ -198,7 +216,7 @@ def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_batch_complete("topic", 0, 3) + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready_ctx.tracker.drop_row(0, 1) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) @@ -209,9 +227,9 @@ def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> Non def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_batch_complete("topic", 0, 3) - ready_ctx.tracker.mark_complete("question", 0, 0) - ready_ctx.tracker.mark_complete("question", 0, 1) + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + ready_ctx.tracker.mark_cell_complete("question", 0, 0) + ready_ctx.tracker.mark_cell_complete("question", 0, 1) # question[0,2] not done yet ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) @@ -221,9 +239,9 @@ def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFi def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_batch_complete("topic", 0, 3) + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) for ri in range(3): - ready_ctx.tracker.mark_complete("question", 0, ri) + ready_ctx.tracker.mark_cell_complete("question", 0, ri) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) @@ -237,8 +255,8 @@ def test_get_ready_tasks_multiple_row_groups() -> None: tracker = CompletionTracker(graph, [(0, 3), (1, 2)]) dispatched: set[Task] = set() - tracker.mark_batch_complete("topic", 0, 3) - tracker.mark_batch_complete("topic", 1, 2) + tracker.mark_row_range_complete("topic", 0, 3) + tracker.mark_row_range_complete("topic", 1, 2) ready = tracker.get_ready_tasks(dispatched) @@ -247,7 +265,7 @@ def test_get_ready_tasks_multiple_row_groups() -> None: def test_get_ready_tasks_skips_already_complete_batch(ready_ctx: ReadyTasksFixture) -> None: - ready_ctx.tracker.mark_batch_complete("topic", 0, 3) + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 454397930..8af6d42ca 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -18,10 +18,8 @@ from data_designer.config.validator_params import CodeValidatorParams from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError -from data_designer.engine.dataset_builders.utils.execution_graph import ( - ExecutionGraph, - build_execution_graph, -) +from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph +from data_designer.engine.dataset_builders.utils.task_model import CellRef MODEL_ALIAS = "stub-model-alias" @@ -55,7 +53,7 @@ def simple_graph( simple_pipeline_configs: list, simple_pipeline_strategies: dict[str, GenerationStrategy], ) -> ExecutionGraph: - return build_execution_graph(simple_pipeline_configs, simple_pipeline_strategies) + return ExecutionGraph.create(simple_pipeline_configs, simple_pipeline_strategies) # -- Graph construction tests ------------------------------------------------ @@ -63,17 +61,17 @@ def simple_graph( def test_build_basic_graph(simple_graph: ExecutionGraph) -> None: assert simple_graph.columns == ["topic", "question", "answer", "score"] - assert simple_graph.upstream("topic") == set() - assert simple_graph.upstream("question") == {"topic"} - assert simple_graph.upstream("answer") == {"question"} - assert simple_graph.upstream("score") == {"answer"} + assert simple_graph.get_upstream_columns("topic") == set() + assert simple_graph.get_upstream_columns("question") == {"topic"} + assert simple_graph.get_upstream_columns("answer") == {"question"} + assert simple_graph.get_upstream_columns("score") == {"answer"} -def test_downstream(simple_graph: ExecutionGraph) -> None: - assert simple_graph.downstream("topic") == {"question"} - assert simple_graph.downstream("question") == {"answer"} - assert simple_graph.downstream("answer") == {"score"} - assert simple_graph.downstream("score") == set() +def test_get_downstream_columns(simple_graph: ExecutionGraph) -> None: + assert simple_graph.get_downstream_columns("topic") == {"question"} + assert simple_graph.get_downstream_columns("question") == {"answer"} + assert simple_graph.get_downstream_columns("answer") == {"score"} + assert simple_graph.get_downstream_columns("score") == set() def test_strategy(simple_graph: ExecutionGraph) -> None: @@ -81,14 +79,14 @@ def test_strategy(simple_graph: ExecutionGraph) -> None: assert simple_graph.strategy("question") == GenerationStrategy.CELL_BY_CELL -def test_unknown_column_upstream() -> None: +def test_unknown_column_get_upstream_columns() -> None: graph = ExecutionGraph() - assert graph.upstream("nonexistent") == set() + assert graph.get_upstream_columns("nonexistent") == set() -def test_unknown_column_downstream() -> None: +def test_unknown_column_get_downstream_columns() -> None: graph = ExecutionGraph() - assert graph.downstream("nonexistent") == set() + assert graph.get_downstream_columns("nonexistent") == set() # -- Side-effect resolution ------------------------------------------------- @@ -108,10 +106,10 @@ def test_side_effect_column_resolution() -> None: "summary": GenerationStrategy.CELL_BY_CELL, "trace_len": GenerationStrategy.FULL_COLUMN, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) - assert graph.upstream("trace_len") == {"summary"} - assert graph.downstream("summary") == {"trace_len"} + assert graph.get_upstream_columns("trace_len") == {"summary"} + assert graph.get_downstream_columns("summary") == {"trace_len"} def test_reasoning_content_side_effect() -> None: @@ -128,9 +126,9 @@ def test_reasoning_content_side_effect() -> None: "answer": GenerationStrategy.CELL_BY_CELL, "reasoning": GenerationStrategy.FULL_COLUMN, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) - assert graph.upstream("reasoning") == {"answer"} + assert graph.get_upstream_columns("reasoning") == {"answer"} def test_side_effect_name_collision_prefers_real_column() -> None: @@ -149,11 +147,11 @@ def test_side_effect_name_collision_prefers_real_column() -> None: "summary__trace": GenerationStrategy.FULL_COLUMN, "trace_len": GenerationStrategy.FULL_COLUMN, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) - assert graph.upstream("trace_len") == {"summary__trace"} - assert graph.downstream("summary__trace") == {"trace_len"} - assert graph.downstream("summary") == set() + assert graph.get_upstream_columns("trace_len") == {"summary__trace"} + assert graph.get_downstream_columns("summary__trace") == {"trace_len"} + assert graph.get_downstream_columns("summary") == set() # -- Validation tests ------------------------------------------------------- @@ -169,7 +167,7 @@ def test_circular_dependency_raises() -> None: "col_b": GenerationStrategy.CELL_BY_CELL, } with pytest.raises(DAGCircularDependencyError): - build_execution_graph(configs, strategies) + ExecutionGraph.create(configs, strategies) def test_unknown_required_column_raises() -> None: @@ -178,7 +176,7 @@ def test_unknown_required_column_raises() -> None: ] strategies = {"col_a": GenerationStrategy.CELL_BY_CELL} with pytest.raises(ValueError, match="not a known producer"): - build_execution_graph(configs, strategies) + ExecutionGraph.create(configs, strategies) # -- Topological order ------------------------------------------------------ @@ -207,7 +205,7 @@ def test_parallel_columns_topological_order() -> None: "branch_b": GenerationStrategy.CELL_BY_CELL, "merge": GenerationStrategy.FULL_COLUMN, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) order = graph.topological_order() idx = {col: i for i, col in enumerate(order)} @@ -220,12 +218,17 @@ def test_parallel_columns_topological_order() -> None: # -- Critical path ---------------------------------------------------------- -def test_critical_path(simple_graph: ExecutionGraph) -> None: - path = simple_graph.critical_path() +def test_get_longest_dependency_chain_empty_graph() -> None: + graph = ExecutionGraph() + assert graph.get_longest_dependency_chain() == [] + + +def test_get_longest_dependency_chain(simple_graph: ExecutionGraph) -> None: + path = simple_graph.get_longest_dependency_chain() assert path == ["topic", "question", "answer", "score"] -def test_critical_path_diamond() -> None: +def test_get_longest_dependency_chain_diamond() -> None: """Diamond: seed → (a, b) → merge. Path is seed → a/b → merge (length 3).""" configs = [ SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["X"]}), @@ -239,8 +242,8 @@ def test_critical_path_diamond() -> None: "b": GenerationStrategy.CELL_BY_CELL, "merge": GenerationStrategy.FULL_COLUMN, } - graph = build_execution_graph(configs, strategies) - path = graph.critical_path() + graph = ExecutionGraph.create(configs, strategies) + path = graph.get_longest_dependency_chain() assert len(path) == 3 assert path[0] == "seed" @@ -273,13 +276,13 @@ def test_cell_deps_cell_by_cell_upstream(simple_graph: ExecutionGraph) -> None: """question depends on topic (full-column); answer depends on question (cell-by-cell).""" # answer[rg=0, row=2] should depend on question[rg=0, row=2] deps = simple_graph.cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) - assert deps == [("question", 0, 2)] + assert deps == [CellRef("question", 0, 2)] def test_cell_deps_full_column_upstream(simple_graph: ExecutionGraph) -> None: """question depends on topic (full-column).""" deps = simple_graph.cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) - assert deps == [("topic", 0, None)] + assert deps == [CellRef("topic", 0, None)] def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: @@ -291,7 +294,7 @@ def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: def test_cell_deps_full_column_downstream_of_cell_by_cell(simple_graph: ExecutionGraph) -> None: """score (full-column) depends on answer (cell-by-cell) → needs ALL rows.""" deps = simple_graph.cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) - assert sorted(deps) == [("answer", 0, 0), ("answer", 0, 1), ("answer", 0, 2)] + assert sorted(deps) == [CellRef("answer", 0, 0), CellRef("answer", 0, 1), CellRef("answer", 0, 2)] # -- Mermaid output ---------------------------------------------------------- @@ -324,11 +327,11 @@ def test_multi_column_config() -> None: "first_name": GenerationStrategy.FULL_COLUMN, "last_name": GenerationStrategy.FULL_COLUMN, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) assert set(graph.columns) == {"first_name", "last_name"} - assert graph.upstream("first_name") == set() - assert graph.upstream("last_name") == set() + assert graph.get_upstream_columns("first_name") == set() + assert graph.get_upstream_columns("last_name") == set() def test_multi_column_with_downstream_dependency() -> None: @@ -349,9 +352,9 @@ def test_multi_column_with_downstream_dependency() -> None: "last_name": GenerationStrategy.FULL_COLUMN, "greeting": GenerationStrategy.CELL_BY_CELL, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) - assert graph.upstream("greeting") == {"first_name", "last_name"} + assert graph.get_upstream_columns("greeting") == {"first_name", "last_name"} # -- Validation column dependency ------------------------------------------- @@ -376,10 +379,10 @@ def test_validation_column_dependency() -> None: "code": GenerationStrategy.CELL_BY_CELL, "validation": GenerationStrategy.FULL_COLUMN, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) - assert graph.upstream("validation") == {"code"} - assert graph.downstream("code") == {"validation"} + assert graph.get_upstream_columns("validation") == {"code"} + assert graph.get_downstream_columns("code") == {"validation"} # -- Judge column dependency ------------------------------------------------ @@ -399,6 +402,6 @@ def test_judge_column_dependency() -> None: "text": GenerationStrategy.CELL_BY_CELL, "judge": GenerationStrategy.CELL_BY_CELL, } - graph = build_execution_graph(configs, strategies) + graph = ExecutionGraph.create(configs, strategies) - assert graph.upstream("judge") == {"text"} + assert graph.get_upstream_columns("judge") == {"text"} From 82b8351d09df671c47d45732aa1512fd9c3c73ec Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 5 Mar 2026 11:09:33 -0300 Subject: [PATCH 09/17] fix: prevent completed tasks from re-entering the frontier Skip adding downstream tasks to the frontier when they are already marked complete, avoiding redundant work in CompletionTracker. --- .../dataset_builders/utils/completion_tracker.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index b4e90c346..d98fced81 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -97,7 +97,12 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None cell_up_completed = [rg_completed.get(up, set()) for up in cell_ups] if row_index is not None: # Cell completion: only check the same row - if row_index not in rg_dropped and all(row_index in s for s in cell_up_completed): + down_completed = rg_completed.get(down, set()) + if ( + row_index not in rg_dropped + and row_index not in down_completed + and all(row_index in s for s in cell_up_completed) + ): task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell") self._frontier.add(task) else: @@ -111,7 +116,9 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None self._frontier.add(task) else: # FULL_COLUMN downstream: ready when all cell upstreams are fully complete - if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped): + if down not in rg_completed and self._are_cell_ups_complete( + cell_ups, rg_completed, rg_size, rg_dropped + ): task = Task(column=down, row_group=row_group, row_index=None, task_type="batch") self._frontier.add(task) From 3b539b6c6c68938a46d2df9225f479f168b4279b Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 5 Mar 2026 11:29:03 -0300 Subject: [PATCH 10/17] harden completion tracker and execution graph APIs - Enforce strategy-safe completion: mark_cell_complete rejects non-CELL_BY_CELL columns, mark_row_range_complete rejects CELL_BY_CELL columns (ValueError in graph mode) - Return defensive copies from ExecutionGraph public API (columns, get_upstream/downstream_columns) - Add re-enqueue regression tests for cell and batch paths - Add immutability tests for ExecutionGraph collections --- .../utils/completion_tracker.py | 10 +++ .../dataset_builders/utils/execution_graph.py | 8 +- .../utils/test_completion_tracker.py | 75 +++++++++++++++++++ .../utils/test_execution_graph.py | 31 ++++++++ 4 files changed, 120 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index d98fced81..6f6919203 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -62,6 +62,7 @@ def _seed_frontier(self) -> None: def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None: self._validate_row_group(row_group) + self._validate_strategy(column, GenerationStrategy.CELL_BY_CELL, "mark_cell_complete") self._completed[row_group][column].add(row_index) if self._graph is not None: self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell")) @@ -69,6 +70,7 @@ def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> Non def mark_row_range_complete(self, column: str, row_group: int, row_group_size: int) -> None: expected = self._validate_row_group(row_group) + self._validate_strategy(column, GenerationStrategy.FULL_COLUMN, "mark_row_range_complete") if expected is not None and row_group_size != expected: raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}") self._completed[row_group][column] = set(range(row_group_size)) @@ -210,6 +212,14 @@ def _are_cell_ups_complete( return False return True + def _validate_strategy(self, column: str, expected: GenerationStrategy, method: str) -> None: + """Validate that *column* matches the expected strategy in graph-enabled mode.""" + if self._graph is None: + return + actual = self._graph.strategy(column) + if actual != expected: + raise ValueError(f"{method}() requires {expected.value} strategy, but column '{column}' has {actual.value}") + def _validate_row_group(self, row_group: int) -> int | None: """Validate row-group id in graph-enabled mode and return its expected size.""" if self._graph is None: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 684621541..f00446280 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -58,11 +58,11 @@ def resolve_side_effect(self, column: str) -> str: def get_upstream_columns(self, column: str) -> set[str]: """Direct dependencies of *column*.""" - return self._upstream.get(column, set()) + return set(self._upstream.get(column, set())) def get_downstream_columns(self, column: str) -> set[str]: """Columns that depend on *column*.""" - return self._downstream.get(column, set()) + return set(self._downstream.get(column, set())) def strategy(self, column: str) -> GenerationStrategy: return self._strategies[column] @@ -85,8 +85,8 @@ def upstream_by_strategy(self, column: str) -> tuple[list[str], list[str]]: @property def columns(self) -> list[str]: - """All column names in insertion order. Do not mutate.""" - return self._columns + """All column names in insertion order.""" + return list(self._columns) def topological_order(self) -> list[str]: """Return a valid topological ordering of columns (Kahn's algorithm). diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 5016d9258..138c48c3b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -271,3 +271,78 @@ def test_get_ready_tasks_skips_already_complete_batch(ready_ctx: ReadyTasksFixtu topic_tasks = [t for t in ready if t.column == "topic"] assert len(topic_tasks) == 0 + + +# -- Strategy-safe completion API ------------------------------------------ + + +def test_mark_cell_complete_raises_for_full_column_strategy(ready_ctx: ReadyTasksFixture) -> None: + with pytest.raises(ValueError, match="mark_cell_complete.*requires cell_by_cell.*full_column"): + ready_ctx.tracker.mark_cell_complete("topic", row_group=0, row_index=0) + + +def test_mark_row_range_complete_raises_for_cell_by_cell_strategy(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + with pytest.raises(ValueError, match="mark_row_range_complete.*requires full_column.*cell_by_cell"): + ready_ctx.tracker.mark_row_range_complete("question", row_group=0, row_group_size=3) + + +# -- Re-enqueue regression tests ------------------------------------------- + + +def test_completed_cell_not_reenqueued_after_later_upstream(ready_ctx: ReadyTasksFixture) -> None: + """A → B → C chain: completing C[row=0] then completing another A upstream must not re-enqueue C[row=0].""" + graph = _build_simple_graph() + tracker = CompletionTracker(graph, [(0, 2)]) + dispatched: set[Task] = set() + + # Complete the full pipeline for row 0 + tracker.mark_row_range_complete("topic", 0, 2) + tracker.mark_cell_complete("question", 0, 0) + tracker.mark_cell_complete("question", 0, 1) + + # score should now be ready + ready = tracker.get_ready_tasks(dispatched) + score_tasks = [t for t in ready if t.column == "score"] + assert len(score_tasks) == 1 + + # Complete score, then re-complete an upstream cell — score must not reappear + tracker.mark_row_range_complete("score", 0, 2) + + ready = tracker.get_ready_tasks(dispatched) + score_tasks = [t for t in ready if t.column == "score"] + assert len(score_tasks) == 0 + + +def test_completed_batch_not_reenqueued_by_upstream_cell() -> None: + """After a FULL_COLUMN downstream is completed, a late cell upstream event must not re-add it.""" + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="gen", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="agg", expr="{{ gen }}"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "gen": GenerationStrategy.CELL_BY_CELL, + "agg": GenerationStrategy.FULL_COLUMN, + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker(graph, [(0, 2)]) + dispatched: set[Task] = set() + + # Complete seed and gen[0] — agg not ready yet + tracker.mark_row_range_complete("seed", 0, 2) + tracker.mark_cell_complete("gen", 0, 0) + + ready = tracker.get_ready_tasks(dispatched) + assert not any(t.column == "agg" for t in ready) + + # Complete gen[1] — agg becomes ready + tracker.mark_cell_complete("gen", 0, 1) + ready = tracker.get_ready_tasks(dispatched) + assert any(t.column == "agg" for t in ready) + + # Complete agg, then verify it doesn't reappear + tracker.mark_row_range_complete("agg", 0, 2) + ready = tracker.get_ready_tasks(dispatched) + assert not any(t.column == "agg" for t in ready) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 8af6d42ca..0dda6a81d 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -388,6 +388,37 @@ def test_validation_column_dependency() -> None: # -- Judge column dependency ------------------------------------------------ +# -- Immutability tests ----------------------------------------------------- + + +def test_mutating_columns_does_not_affect_graph(simple_graph: ExecutionGraph) -> None: + cols = simple_graph.columns + cols.append("injected") + assert "injected" not in simple_graph.columns + + +def test_mutating_upstream_does_not_affect_graph(simple_graph: ExecutionGraph) -> None: + ups = simple_graph.get_upstream_columns("question") + ups.add("injected") + assert "injected" not in simple_graph.get_upstream_columns("question") + + +def test_mutating_downstream_does_not_affect_graph(simple_graph: ExecutionGraph) -> None: + downs = simple_graph.get_downstream_columns("topic") + downs.add("injected") + assert "injected" not in simple_graph.get_downstream_columns("topic") + + +def test_mutating_topological_order_does_not_affect_cache(simple_graph: ExecutionGraph) -> None: + order1 = simple_graph.topological_order() + order1.reverse() + order2 = simple_graph.topological_order() + assert order2[0] == "topic" + + +# -- Judge column dependency ------------------------------------------------ + + def test_judge_column_dependency() -> None: configs = [ LLMTextColumnConfig(name="text", prompt="Write something", model_alias=MODEL_ALIAS), From 03741df20ef45a07821841a2030bb2c2f67c7108 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 5 Mar 2026 13:22:10 -0300 Subject: [PATCH 11/17] address second round of review feedback - Reject duplicate column names in add_column with ValueError - Validate buffer_size > 0 in task_count - Use _batch_complete for batch upstream readiness checks - Remove duplicate section header in test file --- .../engine/dataset_builders/utils/completion_tracker.py | 7 ++++--- .../engine/dataset_builders/utils/execution_graph.py | 4 ++++ .../engine/dataset_builders/utils/test_execution_graph.py | 3 --- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 6f6919203..90ed1e19a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -84,13 +84,13 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None assert self._graph is not None rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) + rg_batch_complete = self._batch_complete.get(row_group, set()) rg_size = self._row_group_sizes[row_group] for down in self._graph.get_downstream_columns(column): batch_ups, cell_ups = self._graph.upstream_by_strategy(down) - # All batch upstreams must be present in completed dict - if any(up not in rg_completed for up in batch_ups): + if any(up not in rg_batch_complete for up in batch_ups): continue down_strategy = self._graph.strategy(down) @@ -156,6 +156,7 @@ def _reevaluate_batch_tasks(self, row_group: int) -> None: assert self._graph is not None rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) + rg_batch_complete = self._batch_complete.get(row_group, set()) rg_size = self._row_group_sizes[row_group] for col in self._graph.topological_order(): @@ -164,7 +165,7 @@ def _reevaluate_batch_tasks(self, row_group: int) -> None: if col in rg_completed: continue batch_ups, cell_ups = self._graph.upstream_by_strategy(col) - if any(up not in rg_completed for up in batch_ups): + if any(up not in rg_batch_complete for up in batch_ups): continue if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped): task = Task(column=col, row_group=row_group, row_index=None, task_type="batch") diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index f00446280..c993de3a4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -34,6 +34,8 @@ def __init__(self) -> None: def add_column(self, name: str, strategy: GenerationStrategy) -> None: """Register a column and its generation strategy.""" + if name in self._strategies: + raise ValueError(f"Column '{name}' is already registered.") self._columns.append(name) self._strategies[name] = strategy @@ -150,6 +152,8 @@ def task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: Cell-by-cell columns produce ``num_records`` tasks. Full-column columns (including from-scratch) produce ``ceil(num_records / buffer_size)`` tasks. """ + if buffer_size <= 0: + raise ValueError(f"buffer_size must be a positive integer, got {buffer_size}") num_row_groups = math.ceil(num_records / buffer_size) counts: dict[str, int] = {} for col in self._columns: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 0dda6a81d..150a2a5e7 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -385,9 +385,6 @@ def test_validation_column_dependency() -> None: assert graph.get_downstream_columns("code") == {"validation"} -# -- Judge column dependency ------------------------------------------------ - - # -- Immutability tests ----------------------------------------------------- From f4b10c1807d35e453c4e7bbb869b1c23dd9eabdd Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 5 Mar 2026 18:19:41 -0300 Subject: [PATCH 12/17] fix AGENTS.md compliance violations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add `from __future__ import annotations` to 5 files missing it - Rename ExecutionGraph methods to start with action verbs (strategy → get_strategy, topological_order → get_topological_order, upstream_by_strategy → split_upstream_by_strategy, task_count → compute_task_count, cell_dependencies → compute_cell_dependencies) - Reorder methods in CompletionTracker and ExecutionGraph: __init__ → properties → classmethods → public → private --- .../utils/completion_tracker.py | 148 +++++++++--------- .../dataset_builders/utils/execution_graph.py | 132 ++++++++-------- .../engine/models/clients/__init__.py | 2 + .../models/clients/adapters/__init__.py | 2 + .../utils/test_completion_tracker.py | 1 + .../utils/test_execution_graph.py | 26 +-- .../dataset_builders/utils/test_task_model.py | 2 + 7 files changed, 161 insertions(+), 152 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 90ed1e19a..eda6b6000 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -46,20 +46,6 @@ def __init__( self._row_group_sizes = {rg_id: size for rg_id, size in row_groups} self._seed_frontier() - def _seed_frontier(self) -> None: - """Populate the frontier with root tasks (columns with no upstream deps).""" - assert self._graph is not None - for col in self._graph.topological_order(): - if self._graph.get_upstream_columns(col): - continue - strategy = self._graph.strategy(col) - for rg_id, rg_size in self._row_group_sizes.items(): - if strategy == GenerationStrategy.CELL_BY_CELL: - for ri in range(rg_size): - self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell")) - else: - self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch")) - def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None: self._validate_row_group(row_group) self._validate_strategy(column, GenerationStrategy.CELL_BY_CELL, "mark_cell_complete") @@ -79,6 +65,74 @@ def mark_row_range_complete(self, column: str, row_group: int, row_group_size: i self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch")) self._enqueue_downstream(column, row_group, row_index=None) + def is_complete(self, column: str, row_group: int, row_index: int) -> bool: + return row_index in self._completed.get(row_group, {}).get(column, set()) + + def is_all_complete(self, cells: list[CellRef]) -> bool: + """Check whether all the given cells are done. + + A ``row_index`` of ``None`` means the entire batch for that column must + have been completed via ``mark_row_range_complete``. + """ + for col, rg, ri in cells: + if ri is None: + if col not in self._batch_complete.get(rg, set()): + return False + elif not self.is_complete(col, rg, ri): + return False + return True + + def drop_row(self, row_group: int, row_index: int) -> None: + self._validate_row_group(row_group) + self._dropped[row_group].add(row_index) + if self._graph is not None: + # Remove cell tasks for this row from the frontier + for col in self._graph.columns: + self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell")) + # Dropping a row may unblock batch downstream tasks + self._reevaluate_batch_tasks(row_group) + + def is_dropped(self, row_group: int, row_index: int) -> bool: + return row_index in self._dropped.get(row_group, set()) + + def is_row_group_complete( + self, + row_group: int, + row_group_size: int, + all_columns: list[str], + ) -> bool: + """All non-dropped rows have all columns done.""" + dropped = self._dropped.get(row_group, set()) + completed = self._completed.get(row_group, {}) + for ri in range(row_group_size): + if ri in dropped: + continue + for col in all_columns: + if ri not in completed.get(col, set()): + return False + return True + + def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]: + """Return all currently dispatchable tasks from the frontier. + + Excludes already-dispatched/in-flight tasks. + """ + return [t for t in self._frontier if t not in dispatched] + + def _seed_frontier(self) -> None: + """Populate the frontier with root tasks (columns with no upstream deps).""" + assert self._graph is not None + for col in self._graph.get_topological_order(): + if self._graph.get_upstream_columns(col): + continue + strategy = self._graph.get_strategy(col) + for rg_id, rg_size in self._row_group_sizes.items(): + if strategy == GenerationStrategy.CELL_BY_CELL: + for ri in range(rg_size): + self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell")) + else: + self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch")) + def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> None: """Add newly-ready downstream tasks to the frontier.""" assert self._graph is not None @@ -88,12 +142,12 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None rg_size = self._row_group_sizes[row_group] for down in self._graph.get_downstream_columns(column): - batch_ups, cell_ups = self._graph.upstream_by_strategy(down) + batch_ups, cell_ups = self._graph.split_upstream_by_strategy(down) if any(up not in rg_batch_complete for up in batch_ups): continue - down_strategy = self._graph.strategy(down) + down_strategy = self._graph.get_strategy(down) if down_strategy == GenerationStrategy.CELL_BY_CELL: cell_up_completed = [rg_completed.get(up, set()) for up in cell_ups] @@ -124,33 +178,6 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None task = Task(column=down, row_group=row_group, row_index=None, task_type="batch") self._frontier.add(task) - def is_complete(self, column: str, row_group: int, row_index: int) -> bool: - return row_index in self._completed.get(row_group, {}).get(column, set()) - - def is_all_complete(self, cells: list[CellRef]) -> bool: - """Check whether all the given cells are done. - - A ``row_index`` of ``None`` means the entire batch for that column must - have been completed via ``mark_row_range_complete``. - """ - for col, rg, ri in cells: - if ri is None: - if col not in self._batch_complete.get(rg, set()): - return False - elif not self.is_complete(col, rg, ri): - return False - return True - - def drop_row(self, row_group: int, row_index: int) -> None: - self._validate_row_group(row_group) - self._dropped[row_group].add(row_index) - if self._graph is not None: - # Remove cell tasks for this row from the frontier - for col in self._graph.columns: - self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell")) - # Dropping a row may unblock batch downstream tasks - self._reevaluate_batch_tasks(row_group) - def _reevaluate_batch_tasks(self, row_group: int) -> None: """Check if any batch tasks became ready after a row was dropped.""" assert self._graph is not None @@ -159,45 +186,18 @@ def _reevaluate_batch_tasks(self, row_group: int) -> None: rg_batch_complete = self._batch_complete.get(row_group, set()) rg_size = self._row_group_sizes[row_group] - for col in self._graph.topological_order(): - if self._graph.strategy(col) != GenerationStrategy.FULL_COLUMN: + for col in self._graph.get_topological_order(): + if self._graph.get_strategy(col) != GenerationStrategy.FULL_COLUMN: continue if col in rg_completed: continue - batch_ups, cell_ups = self._graph.upstream_by_strategy(col) + batch_ups, cell_ups = self._graph.split_upstream_by_strategy(col) if any(up not in rg_batch_complete for up in batch_ups): continue if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped): task = Task(column=col, row_group=row_group, row_index=None, task_type="batch") self._frontier.add(task) - def is_dropped(self, row_group: int, row_index: int) -> bool: - return row_index in self._dropped.get(row_group, set()) - - def is_row_group_complete( - self, - row_group: int, - row_group_size: int, - all_columns: list[str], - ) -> bool: - """All non-dropped rows have all columns done.""" - dropped = self._dropped.get(row_group, set()) - completed = self._completed.get(row_group, {}) - for ri in range(row_group_size): - if ri in dropped: - continue - for col in all_columns: - if ri not in completed.get(col, set()): - return False - return True - - def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]: - """Return all currently dispatchable tasks from the frontier. - - Excludes already-dispatched/in-flight tasks. - """ - return [t for t in self._frontier if t not in dispatched] - def _are_cell_ups_complete( self, cell_ups: list[str], @@ -217,7 +217,7 @@ def _validate_strategy(self, column: str, expected: GenerationStrategy, method: """Validate that *column* matches the expected strategy in graph-enabled mode.""" if self._graph is None: return - actual = self._graph.strategy(column) + actual = self._graph.get_strategy(column) if actual != expected: raise ValueError(f"{method}() requires {expected.value} strategy, but column '{column}' has {actual.value}") diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index c993de3a4..35e0d4e40 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -32,6 +32,66 @@ def __init__(self) -> None: self._topological_order_cache: list[str] | None = None self._upstream_by_strategy_cache: dict[str, tuple[list[str], list[str]]] = {} + @property + def columns(self) -> list[str]: + """All column names in insertion order.""" + return list(self._columns) + + @classmethod + def create( + cls, + column_configs: list[DatasetBuilderColumnConfigT], + strategies: dict[str, GenerationStrategy], + ) -> ExecutionGraph: + """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. + + Args: + column_configs: Ordered list of ``ColumnConfigT`` or ``MultiColumnConfig``. + strategies: Map of column name → ``GenerationStrategy``, obtained from + each generator's ``get_generation_strategy()``. + """ + graph = cls() + + # First pass: register all columns, strategies, and side-effect mappings + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + graph.add_column(name, strategies[name]) + + for se_col in sub.side_effect_columns: + graph.set_side_effect(se_col, name) + + known_columns = set(graph.columns) + + # Second pass: build edges + for config in column_configs: + if isinstance(config, MultiColumnConfig): + sub_configs = config.columns + else: + sub_configs = [config] + + for sub in sub_configs: + name = sub.name + for req in sub.required_columns: + resolved = graph.resolve_side_effect(req) + if resolved not in known_columns: + raise ValueError( + f"Column '{name}' requires '{req}' (resolved to '{resolved}') which is not a known producer." + ) + if resolved == name: + continue # skip self-dependency + graph.add_edge(upstream=resolved, downstream=name) + + # Validate acyclicity + graph.get_topological_order() + + return graph + def add_column(self, name: str, strategy: GenerationStrategy) -> None: """Register a column and its generation strategy.""" if name in self._strategies: @@ -66,10 +126,10 @@ def get_downstream_columns(self, column: str) -> set[str]: """Columns that depend on *column*.""" return set(self._downstream.get(column, set())) - def strategy(self, column: str) -> GenerationStrategy: + def get_strategy(self, column: str) -> GenerationStrategy: return self._strategies[column] - def upstream_by_strategy(self, column: str) -> tuple[list[str], list[str]]: + def split_upstream_by_strategy(self, column: str) -> tuple[list[str], list[str]]: """Split upstream columns into (batch, cell_by_cell) by strategy. Cached.""" cached = self._upstream_by_strategy_cache.get(column) if cached is not None: @@ -85,12 +145,7 @@ def upstream_by_strategy(self, column: str) -> tuple[list[str], list[str]]: self._upstream_by_strategy_cache[column] = result return result - @property - def columns(self) -> list[str]: - """All column names in insertion order.""" - return list(self._columns) - - def topological_order(self) -> list[str]: + def get_topological_order(self) -> list[str]: """Return a valid topological ordering of columns (Kahn's algorithm). Result is cached after first successful computation since the graph is @@ -125,7 +180,7 @@ def topological_order(self) -> list[str]: def get_longest_dependency_chain(self) -> list[str]: """Longest dependency chain (by number of columns).""" - order = self.topological_order() + order = self.get_topological_order() if not order: return [] dist: dict[str, int] = {col: 0 for col in order} @@ -146,7 +201,7 @@ def get_longest_dependency_chain(self) -> list[str]: path.reverse() return path - def task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: + def compute_task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: """Exact task count per column before the run starts. Cell-by-cell columns produce ``num_records`` tasks. @@ -164,7 +219,7 @@ def task_count(self, num_records: int, buffer_size: int) -> dict[str, int]: counts[col] = num_row_groups return counts - def cell_dependencies( + def compute_cell_dependencies( self, column: str, row_group: int, @@ -199,58 +254,3 @@ def to_mermaid(self) -> str: for dep in sorted(self._upstream.get(col, set())): lines.append(f" {dep} --> {col}") return "\n".join(lines) - - @classmethod - def create( - cls, - column_configs: list[DatasetBuilderColumnConfigT], - strategies: dict[str, GenerationStrategy], - ) -> ExecutionGraph: - """Build an ``ExecutionGraph`` from column configs and pre-computed strategies. - - Args: - column_configs: Ordered list of ``ColumnConfigT`` or ``MultiColumnConfig``. - strategies: Map of column name → ``GenerationStrategy``, obtained from - each generator's ``get_generation_strategy()``. - """ - graph = cls() - - # First pass: register all columns, strategies, and side-effect mappings - for config in column_configs: - if isinstance(config, MultiColumnConfig): - sub_configs = config.columns - else: - sub_configs = [config] - - for sub in sub_configs: - name = sub.name - graph.add_column(name, strategies[name]) - - for se_col in sub.side_effect_columns: - graph.set_side_effect(se_col, name) - - known_columns = set(graph.columns) - - # Second pass: build edges - for config in column_configs: - if isinstance(config, MultiColumnConfig): - sub_configs = config.columns - else: - sub_configs = [config] - - for sub in sub_configs: - name = sub.name - for req in sub.required_columns: - resolved = graph.resolve_side_effect(req) - if resolved not in known_columns: - raise ValueError( - f"Column '{name}' requires '{req}' (resolved to '{resolved}') which is not a known producer." - ) - if resolved == name: - continue # skip self-dependency - graph.add_edge(upstream=resolved, downstream=name) - - # Validate acyclicity - graph.topological_order() - - return graph diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index dec52401a..9cd9bcc62 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.clients.errors import ( ProviderError, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py index 1b65e2dde..cc9feefea 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient, LiteLLMRouter __all__ = ["LiteLLMBridgeClient", "LiteLLMRouter"] diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 138c48c3b..0e163f071 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations from dataclasses import dataclass diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 150a2a5e7..3433de7ba 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import pytest from data_designer.config.column_configs import ( @@ -75,8 +77,8 @@ def test_get_downstream_columns(simple_graph: ExecutionGraph) -> None: def test_strategy(simple_graph: ExecutionGraph) -> None: - assert simple_graph.strategy("topic") == GenerationStrategy.FULL_COLUMN - assert simple_graph.strategy("question") == GenerationStrategy.CELL_BY_CELL + assert simple_graph.get_strategy("topic") == GenerationStrategy.FULL_COLUMN + assert simple_graph.get_strategy("question") == GenerationStrategy.CELL_BY_CELL def test_unknown_column_get_upstream_columns() -> None: @@ -183,7 +185,7 @@ def test_unknown_required_column_raises() -> None: def test_topological_order(simple_graph: ExecutionGraph) -> None: - order = simple_graph.topological_order() + order = simple_graph.get_topological_order() idx = {col: i for i, col in enumerate(order)} assert idx["topic"] < idx["question"] @@ -206,7 +208,7 @@ def test_parallel_columns_topological_order() -> None: "merge": GenerationStrategy.FULL_COLUMN, } graph = ExecutionGraph.create(configs, strategies) - order = graph.topological_order() + order = graph.get_topological_order() idx = {col: i for i, col in enumerate(order)} assert idx["seed"] < idx["branch_a"] @@ -254,7 +256,7 @@ def test_get_longest_dependency_chain_diamond() -> None: def test_task_count(simple_graph: ExecutionGraph) -> None: - counts = simple_graph.task_count(num_records=10, buffer_size=3) + counts = simple_graph.compute_task_count(num_records=10, buffer_size=3) assert counts["topic"] == 4 # ceil(10/3) = 4 row groups assert counts["question"] == 10 # cell-by-cell @@ -263,7 +265,7 @@ def test_task_count(simple_graph: ExecutionGraph) -> None: def test_task_count_exact_divisor(simple_graph: ExecutionGraph) -> None: - counts = simple_graph.task_count(num_records=9, buffer_size=3) + counts = simple_graph.compute_task_count(num_records=9, buffer_size=3) assert counts["topic"] == 3 assert counts["question"] == 9 @@ -275,25 +277,25 @@ def test_task_count_exact_divisor(simple_graph: ExecutionGraph) -> None: def test_cell_deps_cell_by_cell_upstream(simple_graph: ExecutionGraph) -> None: """question depends on topic (full-column); answer depends on question (cell-by-cell).""" # answer[rg=0, row=2] should depend on question[rg=0, row=2] - deps = simple_graph.cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) + deps = simple_graph.compute_cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) assert deps == [CellRef("question", 0, 2)] def test_cell_deps_full_column_upstream(simple_graph: ExecutionGraph) -> None: """question depends on topic (full-column).""" - deps = simple_graph.cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) + deps = simple_graph.compute_cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) assert deps == [CellRef("topic", 0, None)] def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: """topic has no upstream.""" - deps = simple_graph.cell_dependencies("topic", row_group=0, row_index=None, row_group_size=5) + deps = simple_graph.compute_cell_dependencies("topic", row_group=0, row_index=None, row_group_size=5) assert deps == [] def test_cell_deps_full_column_downstream_of_cell_by_cell(simple_graph: ExecutionGraph) -> None: """score (full-column) depends on answer (cell-by-cell) → needs ALL rows.""" - deps = simple_graph.cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) + deps = simple_graph.compute_cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) assert sorted(deps) == [CellRef("answer", 0, 0), CellRef("answer", 0, 1), CellRef("answer", 0, 2)] @@ -407,9 +409,9 @@ def test_mutating_downstream_does_not_affect_graph(simple_graph: ExecutionGraph) def test_mutating_topological_order_does_not_affect_cache(simple_graph: ExecutionGraph) -> None: - order1 = simple_graph.topological_order() + order1 = simple_graph.get_topological_order() order1.reverse() - order2 = simple_graph.topological_order() + order2 = simple_graph.get_topological_order() assert order2[0] == "topic" diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py index f9e2f9fc4..5d5716213 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import pytest from data_designer.engine.dataset_builders.utils.task_model import Task, TaskResult, TaskTrace From 15fef7f3fe78ad3f40bdd23e27c596e16354db35 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 5 Mar 2026 18:51:17 -0300 Subject: [PATCH 13/17] address review feedback: CellRef dataclass, batch done-guard, is_complete API - Convert CellRef from NamedTuple to frozen dataclass - Change is_complete to accept CellRef instead of 3 positional args - Unify batch done-guards in _enqueue_downstream and _reevaluate_batch_tasks to use rg_batch_complete instead of rg_completed --- .../dataset_builders/utils/completion_tracker.py | 16 ++++++++-------- .../engine/dataset_builders/utils/task_model.py | 5 +++-- .../utils/test_completion_tracker.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index eda6b6000..d395e4e71 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -65,8 +65,8 @@ def mark_row_range_complete(self, column: str, row_group: int, row_group_size: i self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch")) self._enqueue_downstream(column, row_group, row_index=None) - def is_complete(self, column: str, row_group: int, row_index: int) -> bool: - return row_index in self._completed.get(row_group, {}).get(column, set()) + def is_complete(self, cell: CellRef) -> bool: + return cell.row_index in self._completed.get(cell.row_group, {}).get(cell.column, set()) def is_all_complete(self, cells: list[CellRef]) -> bool: """Check whether all the given cells are done. @@ -74,11 +74,11 @@ def is_all_complete(self, cells: list[CellRef]) -> bool: A ``row_index`` of ``None`` means the entire batch for that column must have been completed via ``mark_row_range_complete``. """ - for col, rg, ri in cells: - if ri is None: - if col not in self._batch_complete.get(rg, set()): + for cell in cells: + if cell.row_index is None: + if cell.column not in self._batch_complete.get(cell.row_group, set()): return False - elif not self.is_complete(col, rg, ri): + elif not self.is_complete(cell): return False return True @@ -172,7 +172,7 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None self._frontier.add(task) else: # FULL_COLUMN downstream: ready when all cell upstreams are fully complete - if down not in rg_completed and self._are_cell_ups_complete( + if down not in rg_batch_complete and self._are_cell_ups_complete( cell_ups, rg_completed, rg_size, rg_dropped ): task = Task(column=down, row_group=row_group, row_index=None, task_type="batch") @@ -189,7 +189,7 @@ def _reevaluate_batch_tasks(self, row_group: int) -> None: for col in self._graph.get_topological_order(): if self._graph.get_strategy(col) != GenerationStrategy.FULL_COLUMN: continue - if col in rg_completed: + if col in rg_batch_complete: continue batch_ups, cell_ups = self._graph.split_upstream_by_strategy(col) if any(up not in rg_batch_complete for up in batch_ups): diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py index fca2964c5..a63b4c7fa 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -4,10 +4,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Literal, NamedTuple +from typing import Any, Literal -class CellRef(NamedTuple): +@dataclass(frozen=True, order=True) +class CellRef: """Reference to a cell (or batch when row_index is None) in the dataset grid.""" column: str diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 0e163f071..bb716e24d 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -70,20 +70,20 @@ def test_mark_and_check_complete() -> None: tracker = CompletionTracker() tracker.mark_cell_complete("col_a", row_group=0, row_index=0) - assert tracker.is_complete("col_a", 0, 0) - assert not tracker.is_complete("col_a", 0, 1) - assert not tracker.is_complete("col_a", 1, 0) - assert not tracker.is_complete("col_b", 0, 0) + assert tracker.is_complete(CellRef("col_a", 0, 0)) + assert not tracker.is_complete(CellRef("col_a", 0, 1)) + assert not tracker.is_complete(CellRef("col_a", 1, 0)) + assert not tracker.is_complete(CellRef("col_b", 0, 0)) def test_mark_row_range_complete() -> None: tracker = CompletionTracker() tracker.mark_row_range_complete("col_a", row_group=0, row_group_size=3) - assert tracker.is_complete("col_a", 0, 0) - assert tracker.is_complete("col_a", 0, 1) - assert tracker.is_complete("col_a", 0, 2) - assert not tracker.is_complete("col_a", 0, 3) + assert tracker.is_complete(CellRef("col_a", 0, 0)) + assert tracker.is_complete(CellRef("col_a", 0, 1)) + assert tracker.is_complete(CellRef("col_a", 0, 2)) + assert not tracker.is_complete(CellRef("col_a", 0, 3)) def test_mark_row_range_complete_raises_on_size_mismatch(ready_ctx: ReadyTasksFixture) -> None: From 5d599e89d245b99a331a066f69e0bb0d2267511a Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 5 Mar 2026 19:04:10 -0300 Subject: [PATCH 14/17] fix test to use cell_by_cell column for row-group validation test --- .../engine/dataset_builders/utils/test_completion_tracker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index bb716e24d..225921ac8 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -93,7 +93,7 @@ def test_mark_row_range_complete_raises_on_size_mismatch(ready_ctx: ReadyTasksFi def test_mark_cell_complete_raises_on_unknown_row_group(ready_ctx: ReadyTasksFixture) -> None: with pytest.raises(ValueError, match="Unknown row_group"): - ready_ctx.tracker.mark_cell_complete("topic", row_group=999, row_index=0) + ready_ctx.tracker.mark_cell_complete("question", row_group=999, row_index=0) # -- is_all_complete ----------------------------------------------------------- From db5db2a3b6b7255c1cd503e7249164a11147ba9a Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 6 Mar 2026 11:45:55 -0300 Subject: [PATCH 15/17] address review feedback: constructor, assertions, root columns, test fix - Split CompletionTracker into __init__() + with_graph() classmethod - Replace assert with RuntimeError in private methods - Add get_root_columns() to ExecutionGraph - Remove "no locks needed" from docstring - Fix re-enqueue regression test to exercise the actual scenario - Remove unused ready_ctx fixture parameter --- .../utils/completion_tracker.py | 42 +++++++++---------- .../dataset_builders/utils/execution_graph.py | 4 ++ .../utils/test_completion_tracker.py | 35 +++++----------- 3 files changed, 34 insertions(+), 47 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index d395e4e71..8fdbe7f15 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -16,35 +16,32 @@ class CompletionTracker: """Tracks which cells (column, row_group, row_index) are done. - All access is from the single asyncio event loop thread — no locks needed. Row indices are local to their row group (0-based). - When *graph* and *row_groups* are provided, an event-driven frontier is - maintained so that ``get_ready_tasks`` returns in O(frontier) instead of - scanning all columns × rows × row groups. + Use ``with_graph`` to create a frontier-enabled tracker where + ``get_ready_tasks`` returns in O(frontier) instead of scanning all + columns x rows x row groups. """ - def __init__( - self, - graph: ExecutionGraph | None = None, - row_groups: list[tuple[int, int]] | None = None, - ) -> None: - if (graph is None) != (row_groups is None): - raise ValueError("`graph` and `row_groups` must be provided together.") - + def __init__(self) -> None: # row_group → column → set of completed local row indices self._completed: dict[int, dict[str, set[int]]] = defaultdict(lambda: defaultdict(set)) # row_group → set of dropped row indices self._dropped: dict[int, set[int]] = defaultdict(set) - self._graph = graph + self._graph: ExecutionGraph | None = None self._row_group_sizes: dict[int, int] = {} self._batch_complete: dict[int, set[str]] = defaultdict(set) self._frontier: set[Task] = set() - if graph is not None and row_groups is not None: - self._row_group_sizes = {rg_id: size for rg_id, size in row_groups} - self._seed_frontier() + @classmethod + def with_graph(cls, graph: ExecutionGraph, row_groups: list[tuple[int, int]]) -> CompletionTracker: + """Create a frontier-enabled tracker backed by an execution graph.""" + tracker = cls() + tracker._graph = graph + tracker._row_group_sizes = {rg_id: size for rg_id, size in row_groups} + tracker._seed_frontier() + return tracker def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> None: self._validate_row_group(row_group) @@ -121,10 +118,9 @@ def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]: def _seed_frontier(self) -> None: """Populate the frontier with root tasks (columns with no upstream deps).""" - assert self._graph is not None - for col in self._graph.get_topological_order(): - if self._graph.get_upstream_columns(col): - continue + if self._graph is None: + raise RuntimeError("This method requires a graph to be set.") + for col in self._graph.get_root_columns(): strategy = self._graph.get_strategy(col) for rg_id, rg_size in self._row_group_sizes.items(): if strategy == GenerationStrategy.CELL_BY_CELL: @@ -135,7 +131,8 @@ def _seed_frontier(self) -> None: def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None) -> None: """Add newly-ready downstream tasks to the frontier.""" - assert self._graph is not None + if self._graph is None: + raise RuntimeError("This method requires a graph to be set.") rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) rg_batch_complete = self._batch_complete.get(row_group, set()) @@ -180,7 +177,8 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None def _reevaluate_batch_tasks(self, row_group: int) -> None: """Check if any batch tasks became ready after a row was dropped.""" - assert self._graph is not None + if self._graph is None: + raise RuntimeError("This method requires a graph to be set.") rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) rg_batch_complete = self._batch_complete.get(row_group, set()) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index 35e0d4e40..b5bd32cce 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -129,6 +129,10 @@ def get_downstream_columns(self, column: str) -> set[str]: def get_strategy(self, column: str) -> GenerationStrategy: return self._strategies[column] + def get_root_columns(self) -> list[str]: + """Columns with no upstream dependencies, in topological order.""" + return [col for col in self.get_topological_order() if not self._upstream.get(col)] + def split_upstream_by_strategy(self, column: str) -> tuple[list[str], list[str]]: """Split upstream columns into (batch, cell_by_cell) by strategy. Cached.""" cached = self._upstream_by_strategy_cache.get(column) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 225921ac8..21d414733 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -47,22 +47,11 @@ def ready_ctx() -> ReadyTasksFixture: """CompletionTracker wired to the simple 3-column graph with one row group of size 3.""" graph = _build_simple_graph() return ReadyTasksFixture( - tracker=CompletionTracker(graph, [(0, 3)]), + tracker=CompletionTracker.with_graph(graph, [(0, 3)]), dispatched=set(), ) -def test_tracker_requires_row_groups_with_graph() -> None: - graph = _build_simple_graph() - with pytest.raises(ValueError, match="provided together"): - CompletionTracker(graph=graph, row_groups=None) - - -def test_tracker_requires_graph_with_row_groups() -> None: - with pytest.raises(ValueError, match="provided together"): - CompletionTracker(graph=None, row_groups=[(0, 3)]) - - # -- mark_cell_complete / is_complete -------------------------------------- @@ -253,7 +242,7 @@ def test_get_ready_tasks_full_column_ready_when_all_cells_done(ready_ctx: ReadyT def test_get_ready_tasks_multiple_row_groups() -> None: graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 3), (1, 2)]) + tracker = CompletionTracker.with_graph(graph, [(0, 3), (1, 2)]) dispatched: set[Task] = set() tracker.mark_row_range_complete("topic", 0, 3) @@ -291,25 +280,21 @@ def test_mark_row_range_complete_raises_for_cell_by_cell_strategy(ready_ctx: Rea # -- Re-enqueue regression tests ------------------------------------------- -def test_completed_cell_not_reenqueued_after_later_upstream(ready_ctx: ReadyTasksFixture) -> None: - """A → B → C chain: completing C[row=0] then completing another A upstream must not re-enqueue C[row=0].""" +def test_completed_cell_not_reenqueued_after_later_upstream() -> None: + """A → B → C chain: completing C then firing a late upstream event must not re-enqueue C.""" graph = _build_simple_graph() - tracker = CompletionTracker(graph, [(0, 2)]) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) dispatched: set[Task] = set() - # Complete the full pipeline for row 0 + # Complete the full pipeline tracker.mark_row_range_complete("topic", 0, 2) tracker.mark_cell_complete("question", 0, 0) tracker.mark_cell_complete("question", 0, 1) - - # score should now be ready - ready = tracker.get_ready_tasks(dispatched) - score_tasks = [t for t in ready if t.column == "score"] - assert len(score_tasks) == 1 - - # Complete score, then re-complete an upstream cell — score must not reappear tracker.mark_row_range_complete("score", 0, 2) + # Fire a late upstream cell event after score is already done + tracker.mark_cell_complete("question", 0, 0) + ready = tracker.get_ready_tasks(dispatched) score_tasks = [t for t in ready if t.column == "score"] assert len(score_tasks) == 0 @@ -328,7 +313,7 @@ def test_completed_batch_not_reenqueued_by_upstream_cell() -> None: "agg": GenerationStrategy.FULL_COLUMN, } graph = ExecutionGraph.create(configs, strategies) - tracker = CompletionTracker(graph, [(0, 2)]) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) dispatched: set[Task] = set() # Complete seed and gen[0] — agg not ready yet From 15fae2287ccba13d5194795a246c1095dc17fd30 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 6 Mar 2026 12:28:13 -0300 Subject: [PATCH 16/17] rename CellRef to SliceRef A slice naturally represents both a single cell and a full row group, removing the semantic mismatch of CellRef representing batches. --- .../utils/completion_tracker.py | 16 +++++------ .../dataset_builders/utils/execution_graph.py | 14 +++++----- .../dataset_builders/utils/task_model.py | 4 +-- .../utils/test_completion_tracker.py | 28 +++++++++---------- .../utils/test_execution_graph.py | 8 +++--- 5 files changed, 35 insertions(+), 35 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 8fdbe7f15..b2da9094d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy -from data_designer.engine.dataset_builders.utils.task_model import CellRef, Task +from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task if TYPE_CHECKING: from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -62,20 +62,20 @@ def mark_row_range_complete(self, column: str, row_group: int, row_group_size: i self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch")) self._enqueue_downstream(column, row_group, row_index=None) - def is_complete(self, cell: CellRef) -> bool: - return cell.row_index in self._completed.get(cell.row_group, {}).get(cell.column, set()) + def is_complete(self, ref: SliceRef) -> bool: + return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set()) - def is_all_complete(self, cells: list[CellRef]) -> bool: + def is_all_complete(self, cells: list[SliceRef]) -> bool: """Check whether all the given cells are done. A ``row_index`` of ``None`` means the entire batch for that column must have been completed via ``mark_row_range_complete``. """ - for cell in cells: - if cell.row_index is None: - if cell.column not in self._batch_complete.get(cell.row_group, set()): + for ref in cells: + if ref.row_index is None: + if ref.column not in self._batch_complete.get(ref.row_group, set()): return False - elif not self.is_complete(cell): + elif not self.is_complete(ref): return False return True diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index b5bd32cce..29db09c83 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -12,7 +12,7 @@ MultiColumnConfig, ) from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError -from data_designer.engine.dataset_builders.utils.task_model import CellRef +from data_designer.engine.dataset_builders.utils.task_model import SliceRef class ExecutionGraph: @@ -229,22 +229,22 @@ def compute_cell_dependencies( row_group: int, row_index: int | None, row_group_size: int, - ) -> list[CellRef]: + ) -> list[SliceRef]: """Derive cell-level deps on demand from column-level DAG + strategy. - Returns a list of ``CellRef`` that must be complete before this task can run. + Returns a list of ``SliceRef`` that must be complete before this task can run. """ - deps: list[CellRef] = [] + deps: list[SliceRef] = [] for up_col in self.get_upstream_columns(column): up_strategy = self._strategies[up_col] if up_strategy == GenerationStrategy.CELL_BY_CELL: if row_index is not None: - deps.append(CellRef(up_col, row_group, row_index)) + deps.append(SliceRef(up_col, row_group, row_index)) else: for ri in range(row_group_size): - deps.append(CellRef(up_col, row_group, ri)) + deps.append(SliceRef(up_col, row_group, ri)) else: - deps.append(CellRef(up_col, row_group, None)) + deps.append(SliceRef(up_col, row_group, None)) return deps def to_mermaid(self) -> str: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py index a63b4c7fa..574c594c1 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py @@ -8,8 +8,8 @@ @dataclass(frozen=True, order=True) -class CellRef: - """Reference to a cell (or batch when row_index is None) in the dataset grid.""" +class SliceRef: + """Reference to a slice of the execution grid: a single cell or a full row group.""" column: str row_group: int diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 21d414733..3a0463044 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -16,7 +16,7 @@ from data_designer.config.sampler_params import SamplerType from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.engine.dataset_builders.utils.task_model import CellRef, Task +from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task MODEL_ALIAS = "stub" @@ -59,20 +59,20 @@ def test_mark_and_check_complete() -> None: tracker = CompletionTracker() tracker.mark_cell_complete("col_a", row_group=0, row_index=0) - assert tracker.is_complete(CellRef("col_a", 0, 0)) - assert not tracker.is_complete(CellRef("col_a", 0, 1)) - assert not tracker.is_complete(CellRef("col_a", 1, 0)) - assert not tracker.is_complete(CellRef("col_b", 0, 0)) + assert tracker.is_complete(SliceRef("col_a", 0, 0)) + assert not tracker.is_complete(SliceRef("col_a", 0, 1)) + assert not tracker.is_complete(SliceRef("col_a", 1, 0)) + assert not tracker.is_complete(SliceRef("col_b", 0, 0)) def test_mark_row_range_complete() -> None: tracker = CompletionTracker() tracker.mark_row_range_complete("col_a", row_group=0, row_group_size=3) - assert tracker.is_complete(CellRef("col_a", 0, 0)) - assert tracker.is_complete(CellRef("col_a", 0, 1)) - assert tracker.is_complete(CellRef("col_a", 0, 2)) - assert not tracker.is_complete(CellRef("col_a", 0, 3)) + assert tracker.is_complete(SliceRef("col_a", 0, 0)) + assert tracker.is_complete(SliceRef("col_a", 0, 1)) + assert tracker.is_complete(SliceRef("col_a", 0, 2)) + assert not tracker.is_complete(SliceRef("col_a", 0, 3)) def test_mark_row_range_complete_raises_on_size_mismatch(ready_ctx: ReadyTasksFixture) -> None: @@ -93,15 +93,15 @@ def test_all_complete_cell_level() -> None: tracker.mark_cell_complete("col_a", 0, 0) tracker.mark_cell_complete("col_a", 0, 1) - assert tracker.is_all_complete([CellRef("col_a", 0, 0), CellRef("col_a", 0, 1)]) - assert not tracker.is_all_complete([CellRef("col_a", 0, 0), CellRef("col_a", 0, 2)]) + assert tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 1)]) + assert not tracker.is_all_complete([SliceRef("col_a", 0, 0), SliceRef("col_a", 0, 2)]) def test_all_complete_batch_level() -> None: tracker = CompletionTracker() tracker.mark_row_range_complete("col_a", 0, 3) - assert tracker.is_all_complete([CellRef("col_a", 0, None)]) + assert tracker.is_all_complete([SliceRef("col_a", 0, None)]) def test_all_complete_batch_single_cell_not_sufficient() -> None: @@ -109,12 +109,12 @@ def test_all_complete_batch_single_cell_not_sufficient() -> None: tracker = CompletionTracker() tracker.mark_cell_complete("col_a", 0, 0) - assert not tracker.is_all_complete([CellRef("col_a", 0, None)]) + assert not tracker.is_all_complete([SliceRef("col_a", 0, None)]) def test_all_complete_batch_not_present() -> None: tracker = CompletionTracker() - assert not tracker.is_all_complete([CellRef("col_a", 0, None)]) + assert not tracker.is_all_complete([SliceRef("col_a", 0, None)]) def test_all_complete_empty_list() -> None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index 3433de7ba..f8b54600c 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -21,7 +21,7 @@ from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.engine.dataset_builders.utils.task_model import CellRef +from data_designer.engine.dataset_builders.utils.task_model import SliceRef MODEL_ALIAS = "stub-model-alias" @@ -278,13 +278,13 @@ def test_cell_deps_cell_by_cell_upstream(simple_graph: ExecutionGraph) -> None: """question depends on topic (full-column); answer depends on question (cell-by-cell).""" # answer[rg=0, row=2] should depend on question[rg=0, row=2] deps = simple_graph.compute_cell_dependencies("answer", row_group=0, row_index=2, row_group_size=5) - assert deps == [CellRef("question", 0, 2)] + assert deps == [SliceRef("question", 0, 2)] def test_cell_deps_full_column_upstream(simple_graph: ExecutionGraph) -> None: """question depends on topic (full-column).""" deps = simple_graph.compute_cell_dependencies("question", row_group=0, row_index=1, row_group_size=5) - assert deps == [CellRef("topic", 0, None)] + assert deps == [SliceRef("topic", 0, None)] def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: @@ -296,7 +296,7 @@ def test_cell_deps_no_upstream(simple_graph: ExecutionGraph) -> None: def test_cell_deps_full_column_downstream_of_cell_by_cell(simple_graph: ExecutionGraph) -> None: """score (full-column) depends on answer (cell-by-cell) → needs ALL rows.""" deps = simple_graph.compute_cell_dependencies("score", row_group=0, row_index=None, row_group_size=3) - assert sorted(deps) == [CellRef("answer", 0, 0), CellRef("answer", 0, 1), CellRef("answer", 0, 2)] + assert sorted(deps) == [SliceRef("answer", 0, 0), SliceRef("answer", 0, 1), SliceRef("answer", 0, 2)] # -- Mermaid output ---------------------------------------------------------- From d0c6bf0201728043d0f568c93ee542f42e7da112 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 6 Mar 2026 12:30:42 -0300 Subject: [PATCH 17/17] add missing tests for drop_row unblock, buffer_size, and duplicate column --- .../utils/test_completion_tracker.py | 14 ++++++++++++++ .../dataset_builders/utils/test_execution_graph.py | 13 +++++++++++++ 2 files changed, 27 insertions(+) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py index 3a0463044..b0e9f8024 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py @@ -216,6 +216,20 @@ def test_get_ready_tasks_skips_dropped_rows(ready_ctx: ReadyTasksFixture) -> Non assert {t.row_index for t in question_tasks} == {0, 2} +def test_drop_row_unblocks_full_column_downstream(ready_ctx: ReadyTasksFixture) -> None: + """Dropping the last incomplete CELL_BY_CELL row should make downstream FULL_COLUMN ready.""" + ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) + ready_ctx.tracker.mark_cell_complete("question", 0, 0) + ready_ctx.tracker.mark_cell_complete("question", 0, 1) + # question[2] never completes -- drop it instead + ready_ctx.tracker.drop_row(0, 2) + + ready = ready_ctx.tracker.get_ready_tasks(ready_ctx.dispatched) + score_tasks = [t for t in ready if t.column == "score"] + assert len(score_tasks) == 1 + assert score_tasks[0].task_type == "batch" + + def test_get_ready_tasks_full_column_waits_for_all_cells(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) ready_ctx.tracker.mark_cell_complete("question", 0, 0) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index f8b54600c..9d2fa69c4 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -271,6 +271,19 @@ def test_task_count_exact_divisor(simple_graph: ExecutionGraph) -> None: assert counts["question"] == 9 +@pytest.mark.parametrize("buffer_size", [0, -1]) +def test_task_count_invalid_buffer_size_raises(simple_graph: ExecutionGraph, buffer_size: int) -> None: + with pytest.raises(ValueError, match="buffer_size"): + simple_graph.compute_task_count(num_records=10, buffer_size=buffer_size) + + +def test_add_column_duplicate_raises() -> None: + graph = ExecutionGraph() + graph.add_column("col_a", GenerationStrategy.CELL_BY_CELL) + with pytest.raises(ValueError, match="already registered"): + graph.add_column("col_a", GenerationStrategy.FULL_COLUMN) + + # -- Cell dependencies ------------------------------------------------------