Merged
Changes from 9 commits
Commits (21)
d6b3022
feat: add ExecutionGraph, CompletionTracker, and Task model for async…
andreatgretel Feb 26, 2026
969a87f
refactor: extract readiness helpers and cache topological order
andreatgretel Feb 26, 2026
0aa7d54
refactor: address PR review feedback
andreatgretel Feb 27, 2026
d0d4695
refactor: address remaining PR review feedback
andreatgretel Feb 27, 2026
c30abb4
refactor: event-driven frontier for CompletionTracker
andreatgretel Mar 2, 2026
638c878
refactor: extract ready_ctx fixture in completion tracker tests
andreatgretel Mar 3, 2026
b08cb3d
fix: validate tracker args and resolve side-effect name collisions
andreatgretel Mar 3, 2026
eba61bb
Merge branch 'main' into andreatgretel/feat/async-generators-and-task…
andreatgretel Mar 4, 2026
060b933
Merge branch 'main' into andreatgretel/feat/async-generators-and-task…
andreatgretel Mar 5, 2026
ebb9428
refactor: address PR review feedback — naming, CellRef, batch semantics
andreatgretel Mar 4, 2026
82b8351
fix: prevent completed tasks from re-entering the frontier
andreatgretel Mar 5, 2026
3b539b6
harden completion tracker and execution graph APIs
andreatgretel Mar 5, 2026
03741df
address second round of review feedback
andreatgretel Mar 5, 2026
a8cf010
Merge branch 'main' into andreatgretel/feat/async-generators-and-task…
andreatgretel Mar 5, 2026
f4b10c1
fix AGENTS.md compliance violations
andreatgretel Mar 5, 2026
15fef7f
address review feedback: CellRef dataclass, batch done-guard, is_comp…
andreatgretel Mar 5, 2026
5d599e8
fix test to use cell_by_cell column for row-group validation test
andreatgretel Mar 5, 2026
db5db2a
address review feedback: constructor, assertions, root columns, test fix
andreatgretel Mar 6, 2026
15fae22
rename CellRef to SliceRef
andreatgretel Mar 6, 2026
d0c6bf0
add missing tests for drop_row unblock, buffer_size, and duplicate co…
andreatgretel Mar 6, 2026
7dd6f89
Merge branch 'main' into andreatgretel/feat/async-generators-and-task…
andreatgretel Mar 6, 2026
@@ -0,0 +1,197 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from collections import defaultdict
from typing import TYPE_CHECKING

from data_designer.config.column_configs import GenerationStrategy
from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroupIndex, RowIndex, Task

if TYPE_CHECKING:
    from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph


class CompletionTracker:
"""Tracks which (column, row_group, row_index) tuples are done.
Contributor
@johnnygreco johnnygreco Mar 3, 2026

The tuple (column, row_group, row_index) is a cell, right? Sorry, I have left too many comments about this, but it feels clunky to have to carry the tuple around everywhere 😅

Can this be indexed / framed as cell_{row_group}? Actually, this makes me realize I'm not sure what the range of row_index is. Is it the actual dataset range, so (i, j) = (row_index, column) in the dataset? Or are we resetting the range for each row group?

Contributor

I wondered about that too. Maybe keep track of both?! Somewhere I made a suggestion to use named tuples or a dataclass instead of carrying a tuple around.

Contributor

If we just track the local indices, we should be able to resolve global indices via a property, for example.

Contributor Author

Introduced a CellRef NamedTuple to replace the raw tuples. Row indices are local to their row group (0-based), so for a row group of size 3, the indices are 0, 1, 2. Added a note in the class docstring.

Contributor

Super duper nit: I'm walking back on my comment of using NamedTuple. I don't think we use it yet in this project. I think a lightweight dataclass can serve the same purpose?

Contributor Author

Converted CellRef to @dataclass(frozen=True, order=True). Also took the opportunity to make is_complete accept a CellRef directly instead of 3 positional args — cleans up is_all_complete nicely.
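
A minimal sketch of the CellRef shape described in this thread, for illustration only: the diff shown in this revision still uses the Task model and positional arguments, and the type aliases below are stand-ins for the project's own.

```python
from dataclasses import dataclass

# Stand-ins for the project's type aliases (assumed to be plain aliases).
ColumnName = str
RowGroupIndex = int
RowIndex = int


@dataclass(frozen=True, order=True)
class CellRef:
    """One cell: a column within a row group, at a 0-based row index local to that group."""

    column: ColumnName
    row_group: RowGroupIndex
    row_index: RowIndex


class _TrackerSketch:
    """Hypothetical fragment showing is_complete taking a CellRef instead of three args."""

    def __init__(self) -> None:
        self._completed: dict[RowGroupIndex, dict[ColumnName, set[RowIndex]]] = {}

    def is_complete(self, ref: CellRef) -> bool:
        return ref.row_index in self._completed.get(ref.row_group, {}).get(ref.column, set())

    def is_all_complete(self, refs: list[CellRef]) -> bool:
        # Passing refs directly keeps this a one-liner over is_complete.
        return all(self.is_complete(ref) for ref in refs)
```

(A later commit in this PR renames CellRef to SliceRef.)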


    All access is from the single asyncio event loop thread — no locks needed.
    Row indices are local to their row group (0-based).

    When *graph* and *row_groups* are provided, an event-driven frontier is
    maintained so that ``get_ready_tasks`` returns in O(frontier) instead of
    scanning all columns × rows × row groups.
    """

    def __init__(
        self,
        graph: ExecutionGraph | None = None,
        row_groups: list[tuple[RowGroupIndex, int]] | None = None,
    ) -> None:
        if (graph is None) != (row_groups is None):
            raise ValueError("`graph` and `row_groups` must be provided together.")

        # row_group → column → set of completed local row indices
        self._completed: dict[RowGroupIndex, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set))
        # row_group → set of dropped row indices
        self._dropped: dict[RowGroupIndex, set[RowIndex]] = defaultdict(set)

        self._graph = graph
        self._row_group_sizes: dict[RowGroupIndex, int] = {}
        self._frontier: set[Task] = set()

        if graph is not None and row_groups is not None:
            self._row_group_sizes = {rg_id: size for rg_id, size in row_groups}
            self._seed_frontier()

    def _seed_frontier(self) -> None:
        """Populate the frontier with root tasks (columns with no upstream deps)."""
        assert self._graph is not None
        for col in self._graph.topological_order():
            if self._graph.upstream(col):
                continue
            strategy = self._graph.strategy(col)
            for rg_id, rg_size in self._row_group_sizes.items():
                if strategy == GenerationStrategy.CELL_BY_CELL:
                    for ri in range(rg_size):
                        self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell"))
                else:
                    self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch"))

    def mark_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> None:
        """Record completion of a single cell and enqueue any newly-ready downstream tasks."""
        self._completed[row_group][column].add(row_index)
        if self._graph is not None:
            self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell"))
            self._enqueue_downstream(column, row_group, row_index=row_index)

    def mark_batch_complete(self, column: ColumnName, row_group: RowGroupIndex, row_group_size: int) -> None:
        """Record completion of every row of the column in this row group and enqueue newly-ready downstream tasks."""
        self._completed[row_group][column] = set(range(row_group_size))
        if self._graph is not None:
            self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch"))
            self._enqueue_downstream(column, row_group, row_index=None)

    def _enqueue_downstream(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex | None) -> None:
        """Add newly-ready downstream tasks to the frontier."""
        assert self._graph is not None
        rg_completed = self._completed.get(row_group, {})
        rg_dropped = self._dropped.get(row_group, set())
        rg_size = self._row_group_sizes[row_group]

        for down in self._graph.downstream(column):
            batch_ups, cell_ups = self._graph.upstream_by_strategy(down)

            # All batch upstreams must be present in completed dict
            if any(up not in rg_completed for up in batch_ups):
                continue

            down_strategy = self._graph.strategy(down)

            if down_strategy == GenerationStrategy.CELL_BY_CELL:
                cell_up_completed = [rg_completed.get(up, set()) for up in cell_ups]
                if row_index is not None:
                    # Cell completion: only check the same row
                    if row_index not in rg_dropped and all(row_index in s for s in cell_up_completed):
                        task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell")
                        self._frontier.add(task)
                else:
                    # Batch completion: check all non-dropped, non-complete rows
                    down_completed = rg_completed.get(down, set())
                    for ri in range(rg_size):
                        if ri in rg_dropped or ri in down_completed:
                            continue
                        if all(ri in s for s in cell_up_completed):
                            task = Task(column=down, row_group=row_group, row_index=ri, task_type="cell")
                            self._frontier.add(task)
            else:
                # FULL_COLUMN downstream: ready when all cell upstreams are fully complete
                if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped):
                    task = Task(column=down, row_group=row_group, row_index=None, task_type="batch")
                    self._frontier.add(task)

    def is_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> bool:
        return row_index in self._completed.get(row_group, {}).get(column, set())

    def is_all_complete(self, cells: list[tuple[ColumnName, RowGroupIndex, RowIndex | None]]) -> bool:
        """Check whether all the given (column, row_group, row_index) tuples are done.

        A ``row_index`` of ``None`` means the entire batch for that column must
        be complete (i.e., that column key must exist in the row group's dict).
"""
for col, rg, ri in cells:
if ri is None:
if col not in self._completed.get(rg, {}):
return False
            elif not self.is_complete(col, rg, ri):
                return False
        return True

    def drop_row(self, row_group: RowGroupIndex, row_index: RowIndex) -> None:
        self._dropped[row_group].add(row_index)
        if self._graph is not None:
            # Remove cell tasks for this row from the frontier
            for col in self._graph.columns:
                self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell"))
            # Dropping a row may unblock batch downstream tasks
            self._reevaluate_batch_tasks(row_group)

    def _reevaluate_batch_tasks(self, row_group: RowGroupIndex) -> None:
        """Check if any batch tasks became ready after a row was dropped."""
        assert self._graph is not None
        rg_completed = self._completed.get(row_group, {})
        rg_dropped = self._dropped.get(row_group, set())
        rg_size = self._row_group_sizes[row_group]

        for col in self._graph.topological_order():
            if self._graph.strategy(col) != GenerationStrategy.FULL_COLUMN:
                continue
            if col in rg_completed:
                continue
            batch_ups, cell_ups = self._graph.upstream_by_strategy(col)
            if any(up not in rg_completed for up in batch_ups):
                continue
            if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped):
                task = Task(column=col, row_group=row_group, row_index=None, task_type="batch")
                self._frontier.add(task)

    def is_dropped(self, row_group: RowGroupIndex, row_index: RowIndex) -> bool:
        return row_index in self._dropped.get(row_group, set())

    def is_row_group_complete(
        self,
        row_group: RowGroupIndex,
        row_group_size: int,
        all_columns: list[ColumnName],
    ) -> bool:
        """All non-dropped rows have all columns done."""
        dropped = self._dropped.get(row_group, set())
        completed = self._completed.get(row_group, {})
        for ri in range(row_group_size):
            if ri in dropped:
                continue
            for col in all_columns:
                if ri not in completed.get(col, set()):
                    return False
        return True

    def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]:
        """Return all currently dispatchable tasks from the frontier.

        Excludes already-dispatched/in-flight tasks.
        """
        return [t for t in self._frontier if t not in dispatched]

    def _are_cell_ups_complete(
        self,
        cell_ups: list[ColumnName],
        rg_completed: dict[ColumnName, set[RowIndex]],
        rg_size: int,
        rg_dropped: set[RowIndex],
    ) -> bool:
        """Check all non-dropped rows are complete for each cell-by-cell upstream column."""
        for up in cell_ups:
            up_completed = rg_completed.get(up, set())
            for ri in range(rg_size):
                if ri not in rg_dropped and ri not in up_completed:
                    return False
        return True
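
For reviewers who want to see the event-driven frontier end to end, here is a minimal usage sketch. It is illustration only: StubGraph is a hypothetical stand-in exposing just the methods CompletionTracker calls, since the real ExecutionGraph lives in execution_graph.py and is not part of this file, and the column names and integer row-group ids are made up. CompletionTracker is the class defined in the diff above; its module path is not shown here.

```python
from data_designer.config.column_configs import GenerationStrategy
from data_designer.engine.dataset_builders.utils.task_model import Task


class StubGraph:
    """Hypothetical two-column graph: 'response' depends cell-by-cell on 'prompt'."""

    columns = ["prompt", "response"]

    def topological_order(self):
        return self.columns

    def upstream(self, col):
        return ["prompt"] if col == "response" else []

    def downstream(self, col):
        return ["response"] if col == "prompt" else []

    def strategy(self, col):
        return GenerationStrategy.CELL_BY_CELL

    def upstream_by_strategy(self, col):
        # (batch upstreams, cell-by-cell upstreams)
        return [], self.upstream(col)


tracker = CompletionTracker(graph=StubGraph(), row_groups=[(0, 2)])

# Only the root column's cells are seeded into the frontier.
assert {t.column for t in tracker.get_ready_tasks(dispatched=set())} == {"prompt"}

# Completing prompt[row 0] makes response[row 0] ready without a full rescan.
tracker.mark_complete("prompt", 0, 0)
ready = tracker.get_ready_tasks(dispatched=set())
assert Task(column="response", row_group=0, row_index=0, task_type="cell") in ready
```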