-
Notifications
You must be signed in to change notification settings - Fork 159
feat: add ExecutionGraph, CompletionTracker, and Task model for async scheduler #356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
d6b3022
969a87f
0aa7d54
d0d4695
c30abb4
638c878
b08cb3d
eba61bb
060b933
ebb9428
82b8351
3b539b6
03741df
a8cf010
f4b10c1
15fef7f
5d599e8
db5db2a
15fae22
d0c6bf0
7dd6f89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections import defaultdict | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from data_designer.config.column_configs import GenerationStrategy | ||
| from data_designer.engine.dataset_builders.utils.task_model import ColumnName, RowGroupIndex, RowIndex, Task | ||
|
|
||
| if TYPE_CHECKING: | ||
| from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph | ||
|
|
||
|
|
||
class CompletionTracker:
    """Tracks which (column, row_group, row_index) tuples are done.

    All access is from the single asyncio event loop thread — no locks needed.

    Row indices are local to their row group (0-based).

    When *graph* and *row_groups* are provided, an event-driven frontier is
    maintained so that ``get_ready_tasks`` returns in O(frontier) instead of
    scanning all columns × rows × row groups.
    """

    def __init__(
        self,
        graph: ExecutionGraph | None = None,
        row_groups: list[tuple[RowGroupIndex, int]] | None = None,
    ) -> None:
        """Initialize tracker state, optionally enabling the frontier.

        Args:
            graph: Column dependency graph. When provided together with
                *row_groups*, the ready-task frontier is maintained.
            row_groups: ``(row_group_id, size)`` pairs describing each row
                group's local row count.

        Raises:
            ValueError: If exactly one of *graph* / *row_groups* is given —
                the frontier needs both, and plain completion tracking
                needs neither.
        """
        if (graph is None) != (row_groups is None):
            raise ValueError("`graph` and `row_groups` must be provided together.")

        # row_group → column → set of completed local row indices
        self._completed: dict[RowGroupIndex, dict[ColumnName, set[RowIndex]]] = defaultdict(lambda: defaultdict(set))
        # row_group → set of dropped row indices
        self._dropped: dict[RowGroupIndex, set[RowIndex]] = defaultdict(set)

        self._graph = graph
        self._row_group_sizes: dict[RowGroupIndex, int] = {}
        # Invariant: tasks in the frontier have all upstream dependencies
        # satisfied and are neither completed nor dropped.
        self._frontier: set[Task] = set()

        if graph is not None and row_groups is not None:
            self._row_group_sizes = {rg_id: size for rg_id, size in row_groups}
            self._seed_frontier()

    def _seed_frontier(self) -> None:
        """Populate the frontier with root tasks (columns with no upstream deps)."""
        assert self._graph is not None
        for col in self._graph.topological_order():
            # Only columns with no upstream dependencies are runnable at t=0.
            if self._graph.upstream(col):
                continue
            strategy = self._graph.strategy(col)
            for rg_id, rg_size in self._row_group_sizes.items():
                if strategy == GenerationStrategy.CELL_BY_CELL:
                    # One task per row for cell-by-cell columns.
                    for ri in range(rg_size):
                        self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell"))
                else:
                    # One batch task per row group otherwise.
                    self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch"))

    def mark_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> None:
        """Record completion of a single cell and advance the frontier.

        Removes the corresponding cell task from the frontier (if frontier
        tracking is enabled) and enqueues any downstream tasks that just
        became ready for this row.
        """
        self._completed[row_group][column].add(row_index)

        if self._graph is not None:
            self._frontier.discard(Task(column=column, row_group=row_group, row_index=row_index, task_type="cell"))
            self._enqueue_downstream(column, row_group, row_index=row_index)

    def mark_batch_complete(self, column: ColumnName, row_group: RowGroupIndex, row_group_size: int) -> None:
        """Record completion of an entire column batch for one row group.

        Marks every local row index in ``range(row_group_size)`` complete
        (dropped rows included — downstream readiness checks skip dropped
        rows themselves), then advances the frontier.
        """
        self._completed[row_group][column] = set(range(row_group_size))
        if self._graph is not None:
            self._frontier.discard(Task(column=column, row_group=row_group, row_index=None, task_type="batch"))
            self._enqueue_downstream(column, row_group, row_index=None)

    def _enqueue_downstream(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex | None) -> None:
        """Add newly-ready downstream tasks to the frontier.

        Args:
            column: The column that just completed (a cell or a whole batch).
            row_group: Row group the completion belongs to.
            row_index: Local row of the completed cell, or ``None`` when a
                whole batch completed.
        """
        assert self._graph is not None
        rg_completed = self._completed.get(row_group, {})
        rg_dropped = self._dropped.get(row_group, set())
        rg_size = self._row_group_sizes[row_group]

        for down in self._graph.downstream(column):
            batch_ups, cell_ups = self._graph.upstream_by_strategy(down)

            # All batch upstreams must be present in completed dict
            if any(up not in rg_completed for up in batch_ups):
                continue

            down_strategy = self._graph.strategy(down)

            if down_strategy == GenerationStrategy.CELL_BY_CELL:
                cell_up_completed = [rg_completed.get(up, set()) for up in cell_ups]
                if row_index is not None:
                    # Cell completion: only check the same row
                    # NOTE(review): unlike the batch branch below, this path
                    # does not skip rows already complete in `down`; presumed
                    # unreachable since `down` cannot finish before its
                    # upstreams — confirm, or add the guard for symmetry.
                    if row_index not in rg_dropped and all(row_index in s for s in cell_up_completed):
                        task = Task(column=down, row_group=row_group, row_index=row_index, task_type="cell")
                        self._frontier.add(task)
                else:
                    # Batch completion: check all non-dropped, non-complete rows
                    down_completed = rg_completed.get(down, set())
                    for ri in range(rg_size):
                        if ri in rg_dropped or ri in down_completed:
                            continue
                        if all(ri in s for s in cell_up_completed):
                            task = Task(column=down, row_group=row_group, row_index=ri, task_type="cell")
                            self._frontier.add(task)
            else:
                # FULL_COLUMN downstream: ready when all cell upstreams are fully complete
                if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped):
                    task = Task(column=down, row_group=row_group, row_index=None, task_type="batch")
                    self._frontier.add(task)

    def is_complete(self, column: ColumnName, row_group: RowGroupIndex, row_index: RowIndex) -> bool:
        """Return True if the given cell has been marked complete."""
        return row_index in self._completed.get(row_group, {}).get(column, set())

    def is_all_complete(self, cells: list[tuple[ColumnName, RowGroupIndex, RowIndex | None]]) -> bool:
        """Check whether all the given (column, row_group, row_index) tuples are done.

        A ``row_index`` of ``None`` means the entire batch for that column must
        be complete (i.e., that column key must exist in the row group's dict).
        """
        for col, rg, ri in cells:
            if ri is None:
                # Batch requirement: the column key only exists once
                # mark_batch_complete (or a cell completion) touched it.
                if col not in self._completed.get(rg, {}):
                    return False
            elif not self.is_complete(col, rg, ri):
                return False
        return True

    def drop_row(self, row_group: RowGroupIndex, row_index: RowIndex) -> None:
        """Mark a row as dropped and prune/refresh the frontier.

        Dropped rows are excluded from all future readiness checks; any
        pending cell tasks for the row are removed from the frontier.
        """
        self._dropped[row_group].add(row_index)
        if self._graph is not None:
            # Remove cell tasks for this row from the frontier
            for col in self._graph.columns:
                self._frontier.discard(Task(column=col, row_group=row_group, row_index=row_index, task_type="cell"))
            # Dropping a row may unblock batch downstream tasks
            self._reevaluate_batch_tasks(row_group)

    def _reevaluate_batch_tasks(self, row_group: RowGroupIndex) -> None:
        """Check if any batch tasks became ready after a row was dropped."""
        assert self._graph is not None
        rg_completed = self._completed.get(row_group, {})
        rg_dropped = self._dropped.get(row_group, set())
        rg_size = self._row_group_sizes[row_group]

        # Only FULL_COLUMN tasks can be unblocked by a drop: their readiness
        # condition quantifies over *all non-dropped rows* of each upstream.
        for col in self._graph.topological_order():
            if self._graph.strategy(col) != GenerationStrategy.FULL_COLUMN:
                continue
            if col in rg_completed:
                continue
            batch_ups, cell_ups = self._graph.upstream_by_strategy(col)
            if any(up not in rg_completed for up in batch_ups):
                continue
            if self._are_cell_ups_complete(cell_ups, rg_completed, rg_size, rg_dropped):
                task = Task(column=col, row_group=row_group, row_index=None, task_type="batch")
                self._frontier.add(task)

    def is_dropped(self, row_group: RowGroupIndex, row_index: RowIndex) -> bool:
        """Return True if the given row has been dropped from its row group."""
        return row_index in self._dropped.get(row_group, set())

    def is_row_group_complete(
        self,
        row_group: RowGroupIndex,
        row_group_size: int,
        all_columns: list[ColumnName],
    ) -> bool:
        """All non-dropped rows have all columns done."""
        dropped = self._dropped.get(row_group, set())
        completed = self._completed.get(row_group, {})
        for ri in range(row_group_size):
            if ri in dropped:
                continue
            for col in all_columns:
                if ri not in completed.get(col, set()):
                    return False
        return True

    def get_ready_tasks(self, dispatched: set[Task]) -> list[Task]:
        """Return all currently dispatchable tasks from the frontier.

        Excludes already-dispatched/in-flight tasks.
        """
        # O(frontier) scan — the event-driven updates above keep the
        # frontier small relative to columns × rows × row groups.
        return [t for t in self._frontier if t not in dispatched]

    def _are_cell_ups_complete(
        self,
        cell_ups: list[ColumnName],
        rg_completed: dict[ColumnName, set[RowIndex]],
        rg_size: int,
        rg_dropped: set[RowIndex],
    ) -> bool:
        """Check all non-dropped rows are complete for each cell-by-cell upstream column."""
        for up in cell_ups:
            up_completed = rg_completed.get(up, set())
            for ri in range(rg_size):
                if ri not in rg_dropped and ri not in up_completed:
                    return False
        return True
Uh oh!
There was an error while loading. Please reload this page.