diff --git a/AGENTS.md b/AGENTS.md index abe7aaa58..fdb5445f2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -117,7 +117,7 @@ make coverage # Run tests with coverage report - [packages/data-designer/src/data_designer/interface/data_designer.py](packages/data-designer/src/data_designer/interface/data_designer.py) - Main entry point (`DataDesigner` class) - [packages/data-designer-config/src/data_designer/config/config_builder.py](packages/data-designer-config/src/data_designer/config/config_builder.py) - Configuration API (`DataDesignerConfigBuilder`) - [packages/data-designer-config/src/data_designer/config/__init__.py](packages/data-designer-config/src/data_designer/config/__init__.py) - User-facing config API exports -- [packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py](packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py) - Generation orchestrator +- [packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py](packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py) - Generation orchestrator - [pyproject.toml](pyproject.toml) - Project dependencies and tool configurations - [Makefile](Makefile) - Common development commands diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index 98908cc7f..c2e4269ec 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -94,6 +94,11 @@ class RunConfig(ConfigBase): Default is 0. async_trace: If True, collect per-task tracing data when using the async engine (DATA_DESIGNER_ASYNC_ENGINE=1). Has no effect on the sync path. Default is False. + progress_bar: If True, display sticky ANSI progress bars instead of periodic log lines + during generation. Requires a TTY; falls back to log lines in non-TTY environments. + Default is False. + progress_interval: How often (in seconds) the async progress reporter emits a + consolidated log block. Must be > 0. Default is 5.0. throttle: AIMD throttle tuning parameters. See ``ThrottleConfig`` for details. 
""" @@ -105,6 +110,8 @@ class RunConfig(ConfigBase): max_conversation_restarts: int = Field(default=5, ge=0) max_conversation_correction_steps: int = Field(default=0, ge=0) async_trace: bool = False + progress_bar: bool = False + progress_interval: float = Field(default=5.0, gt=0.0) throttle: ThrottleConfig = Field(default_factory=ThrottleConfig) @model_validator(mode="after") diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py index 8e44c56f1..28b79adf3 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/expression.py @@ -10,6 +10,7 @@ from data_designer.config.column_configs import ExpressionColumnConfig from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError +from data_designer.engine.context import format_row_group_tag from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering from data_designer.engine.processing.utils import deserialize_json_values @@ -21,7 +22,7 @@ class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorFullColumn[ExpressionColumnConfig]): def generate(self, data: pd.DataFrame) -> pd.DataFrame: - logger.info(f"🧩 Generating column `{self.config.name}` from expression") + logger.info(f"🧩 {format_row_group_tag()}Generating column `{self.config.name}` from expression") missing_columns = list(set(self.config.required_columns) - set(data.columns)) if len(missing_columns) > 0: diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py index f4bef1cab..7f327239f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/samplers.py @@ -10,6 +10,7 @@ from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy +from data_designer.engine.context import format_row_group_tag from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig from data_designer.engine.errors import DataDesignerRuntimeError from data_designer.engine.processing.utils import concat_datasets @@ -68,7 +69,8 @@ def _log_person_generation_if_needed(self) -> None: def _prepare_for_generation(self, num_records: int) -> SamplingDatasetGenerator: logger.info( - f"🎲 Preparing samplers to generate {num_records} records across {len(self.config.columns)} columns" + f"🎲 {format_row_group_tag()}Preparing samplers to generate" + f" {num_records} records across {len(self.config.columns)} columns" ) self._log_person_generation_if_needed() return self._create_sampling_dataset_generator() diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py index 3e6cdb76f..7a3909889 100644 --- 
a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -10,6 +10,7 @@ from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy from data_designer.engine.column_generators.utils.errors import SeedDatasetError +from data_designer.engine.context import format_row_group_tag from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig from data_designer.engine.processing.utils import concat_datasets from data_designer.logging import LOG_INDENT @@ -89,7 +90,7 @@ def _reset_batch_reader(self, num_records: int) -> None: ) def _sample_records(self, num_records: int) -> pd.DataFrame: - logger.info(f"🌱 Sampling {num_records} records from seed dataset") + logger.info(f"🌱 {format_row_group_tag()}Sampling {num_records} records from seed dataset") logger.info(f"{LOG_INDENT}seed dataset size: {self._seed_dataset_size} records") logger.info(f"{LOG_INDENT}sampling strategy: {self.config.sampling_strategy}") if self._index_range is not None: diff --git a/packages/data-designer-engine/src/data_designer/engine/context.py b/packages/data-designer-engine/src/data_designer/engine/context.py new file mode 100644 index 000000000..500b6bb51 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/context.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from contextvars import ContextVar + +# Set by the async scheduler before executing each task. +# Value: (current_rg_index, total_rg_count) or None. 
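+#
+# Minimal usage sketch (values illustrative), mirroring what the scheduler's
+# _execute_task_inner does around each task:
+#
+#     token = current_row_group.set((2, 12))   # row group index 2 of 12
+#     try:
+#         ...  # format_row_group_tag() returns "(03/12) " in here
+#     finally:
+#         current_row_group.reset(token)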
+current_row_group: ContextVar[tuple[int, int] | None] = ContextVar("current_row_group", default=None)
+
+
+def format_row_group_tag() -> str:
+    """Return a zero-padded '(current/total) ' prefix if a row group context is active, else ''."""
+    rg = current_row_group.get()
+    if rg is None:
+        return ""
+    current, total = rg[0] + 1, rg[1]
+    width = len(str(total))
+    return f"({current:0{width}d}/{total}) "
diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py
index d41243bc5..9795afb70 100644
--- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py
+++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py
@@ -10,12 +10,19 @@
 from collections import deque
 from collections.abc import Coroutine
 from dataclasses import dataclass
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable

 import data_designer.lazy_heavy_imports as lazy
+from data_designer.config.column_configs import GenerationStrategy
+from data_designer.engine.context import current_row_group
+from data_designer.engine.dataset_builders.utils.async_progress_reporter import (
+    DEFAULT_REPORT_INTERVAL,
+    AsyncProgressReporter,
+)
 from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker
-from data_designer.engine.dataset_builders.utils.task_model import Task, TaskTrace
+from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker
+from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar
+from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task, TaskTrace
 from data_designer.engine.models.errors import (
     ModelAPIConnectionError,
     ModelInternalServerError,
@@ -78,14 +85,17 @@ def __init__(
         max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE,
         max_llm_wait_tasks: int = DEFAULT_TASK_POOL_SIZE,
         salvage_max_rounds: int = 2,
-        on_row_group_complete: Callable[[int], None] | None = None,
-        on_checkpoint_complete: Callable[[Path | str], None] | None = None,
+        on_finalize_row_group: Callable[[int], None] | None = None,
         on_seeds_complete: Callable[[int, int], None] | None = None,
         on_before_checkpoint: Callable[[int, int], None] | None = None,
         shutdown_error_rate: float = 0.5,
         shutdown_error_window: int = 10,
         disable_early_shutdown: bool = False,
         trace: bool = False,
+        num_records: int = 0,
+        buffer_size: int = 0,
+        progress_interval: float | None = None,
+        progress_bar: bool = False,
     ) -> None:
         self._generators = generators
         self._graph = graph
@@ -104,8 +114,7 @@ def __init__(
         self._worker_tasks: set[asyncio.Task] = set()
         self._wake_event = asyncio.Event()
         self._salvage_max_rounds = salvage_max_rounds
-        self._on_row_group_complete = on_row_group_complete
-        self._on_checkpoint_complete = on_checkpoint_complete
+        self._on_finalize_row_group = on_finalize_row_group
         self._on_seeds_complete = on_seeds_complete
         self._on_before_checkpoint = on_before_checkpoint

@@ -147,6 +156,40 @@ def __init__(
         # Pre-compute seed columns (graph is static)
         self._seed_cols: frozenset[str] = frozenset(c for c in graph.columns if not graph.get_upstream_columns(c))
+
+        # Per-column progress tracking (cell-by-cell only; full-column tasks are instant)
+        self._progress_bar = StickyProgressBar() if progress_bar else None
+        self._reporter = self._setup_async_progress_reporter(num_records, buffer_size, progress_interval)
+
+    def 
_setup_async_progress_reporter( + self, + num_records: int, + buffer_size: int, + progress_interval: float | None, + ) -> AsyncProgressReporter | None: + if num_records <= 0 or buffer_size <= 0: + return None + + task_counts = self._graph.compute_task_count(num_records, buffer_size) + trackers: dict[str, ProgressTracker] = {} + for col in self._graph.columns: + if self._graph.get_strategy(col) != GenerationStrategy.CELL_BY_CELL: + continue + trackers[col] = ProgressTracker( + total_records=task_counts[col], + label=f"column '{col}'", + quiet=True, + ) + + if not trackers: + return None + + interval = progress_interval if progress_interval is not None else DEFAULT_REPORT_INTERVAL + return AsyncProgressReporter( + trackers, + report_interval=interval, + progress_bar=self._progress_bar, + ) + def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task: """Create a tracked worker task that auto-removes itself on completion.""" task = asyncio.create_task(coro) @@ -187,36 +230,42 @@ async def run(self) -> None: seed_cols = self._seed_cols has_pre_batch = self._on_seeds_complete is not None - # Launch admission as a background task so it interleaves with dispatch. - admission_task = asyncio.create_task(self._admit_row_groups()) + num_rgs = len(self._row_groups) - try: - # Main dispatch loop - await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns) - - # Cancel admission if still running - if not admission_task.done(): - admission_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await admission_task - - # Phase 3: Salvage rounds for retryable failures - await self._salvage_rounds(seed_cols, has_pre_batch, all_columns) - - if self._rg_states: - incomplete = list(self._rg_states) - logger.error( - f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " - "These row groups were not checkpointed." - ) + with self._progress_bar or contextlib.nullcontext(): + if self._reporter: + self._reporter.log_start(num_row_groups=num_rgs) + + # Launch admission as a background task so it interleaves with dispatch. + admission_task = asyncio.create_task(self._admit_row_groups()) + + try: + # Main dispatch loop + await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns) + + # Cancel admission if still running + if not admission_task.done(): + admission_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await admission_task + + if self._reporter: + self._reporter.log_final() + + if self._rg_states: + incomplete = list(self._rg_states) + logger.error( + f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " + "These row groups were not checkpointed." 
+ ) - except asyncio.CancelledError: - if not admission_task.done(): - admission_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await admission_task - await asyncio.shield(self._cancel_workers()) - raise + except asyncio.CancelledError: + if not admission_task.done(): + admission_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await admission_task + await asyncio.shield(self._cancel_workers()) + raise async def _main_dispatch_loop( self, @@ -228,6 +277,8 @@ async def _main_dispatch_loop( while True: if self._early_shutdown: logger.warning("Early shutdown triggered - error rate exceeded threshold") + if self._deferred: + await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns) self._checkpoint_completed_row_groups(all_columns) break @@ -255,16 +306,20 @@ async def _main_dispatch_loop( self._checkpoint_completed_row_groups(all_columns) + # Eagerly salvage any row groups that have only deferred tasks, + # even if other row groups are still in-flight. This frees + # semaphore slots so admission doesn't lose capacity. + if self._deferred: + await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns) + # Are we done? all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight if all_done: break - # All admitted RGs finished their non-deferred work but may not be - # "complete" yet (deferred tasks remain for salvage). Exit the main - # loop so salvage rounds can handle them. - if self._all_rgs_admitted and not ready and not self._in_flight: - break + if not ready and not self._in_flight: + if self._all_rgs_admitted: + break if not ready: await self._wake_event.wait() @@ -279,7 +334,7 @@ async def _salvage_rounds( for round_num in range(self._salvage_max_rounds): if not self._deferred: break - logger.info(f"Salvage round {round_num + 1}/{self._salvage_max_rounds}: {len(self._deferred)} tasks") + logger.debug(f"Salvage round {round_num + 1}/{self._salvage_max_rounds}: {len(self._deferred)} tasks") to_retry = self._deferred self._deferred = [] for task in to_retry: @@ -310,6 +365,16 @@ async def _salvage_rounds( self._dispatched.add( Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch") ) + # Re-mark sibling columns as dispatched to mirror _dispatch_seeds + # and prevent _drain_frontier from re-dispatching them. + for sibling in self._instance_to_columns.get(gid, []): + if sibling != task.column: + self._dispatched.add( + Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch") + ) + self._dispatched.add( + Task(column=sibling, row_group=task.row_group, row_index=None, task_type="batch") + ) self._in_flight.add(task) if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count += 1 @@ -341,11 +406,62 @@ async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count += 1 self._spawn_worker(self._execute_task(task)) - if not self._in_flight: + if not ready and not self._in_flight: break + if not self._in_flight: + continue self._wake_event.clear() await self._wake_event.wait() + async def _salvage_stalled_row_groups( + self, + seed_cols: frozenset[str], + has_pre_batch: bool, + all_columns: list[str], + ) -> None: + """Salvage row groups whose tasks are all deferred (0 in-flight). 
+ + Retries deferred tasks inline so the row groups can be checkpointed + and their semaphore slots freed, preventing deadlock when admission + is blocked. + """ + stalled_rgs = { + t.row_group + for t in self._deferred + if (s := self._rg_states.get(t.row_group)) is not None and s.in_flight_count == 0 + } + if not stalled_rgs: + return + + num_rgs = len(self._row_groups) + width = len(str(num_rgs)) + for rg_id in sorted(stalled_rgs): + rg_deferred = [t for t in self._deferred if t.row_group == rg_id] + logger.info(f"🔄 ({rg_id + 1:0{width}d}/{num_rgs}) Salvaging {len(rg_deferred)} deferred task(s)") + + # Partition deferred into stalled (retry now) and other (keep for later). + stalled_deferred = [t for t in self._deferred if t.row_group in stalled_rgs] + other_deferred = [t for t in self._deferred if t.row_group not in stalled_rgs] + self._deferred = stalled_deferred + await self._salvage_rounds(seed_cols, has_pre_batch, all_columns) + # Separate stalled tasks that exhausted retries from any new failures + # that _drain_frontier may have appended for non-stalled row groups. + exhausted = [t for t in self._deferred if t.row_group in stalled_rgs] + newly_deferred = [t for t in self._deferred if t.row_group not in stalled_rgs] + for task in exhausted: + # If the row was already dropped by an earlier task in this loop, + # the skip was already counted; don't also record a failure. + already_dropped = task.row_index is not None and self._tracker.is_dropped(task.row_group, task.row_index) + if not already_dropped and self._reporter: + self._reporter.record_failure(task.column) + if task.row_index is not None: + self._drop_row(task.row_group, task.row_index, exclude_columns={task.column}) + else: + rg_size = self._get_rg_size(task.row_group) + self._drop_row_group(task.row_group, rg_size, exclude_columns={task.column}) + self._checkpoint_completed_row_groups(all_columns) + self._deferred = other_deferred + newly_deferred + def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: """Checkpoint any row groups that reached completion.""" completed = [ @@ -366,28 +482,27 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: f"on_before_checkpoint failed for row group {rg_id}, dropping row group.", exc_info=True, ) - for ri in range(rg_size): - self._tracker.drop_row(rg_id, ri) - if self._buffer_manager: - self._buffer_manager.drop_row(rg_id, ri) + self._drop_row_group(rg_id, rg_size) + if self._buffer_manager: + self._buffer_manager.free_row_group(rg_id) dropped = True - if not dropped and self._buffer_manager is not None: - if self._on_checkpoint_complete is not None: - - def on_complete(final_path: Path | str | None) -> None: - if final_path is not None: - self._on_checkpoint_complete(final_path) - - self._buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete) - else: - self._buffer_manager.checkpoint_row_group(rg_id) - if not dropped and self._on_row_group_complete: - self._on_row_group_complete(rg_id) + # If all rows were dropped (e.g. 
seed failure), free instead of finalizing + if not dropped and all(self._tracker.is_dropped(rg_id, ri) for ri in range(rg_size)): + if self._buffer_manager: + self._buffer_manager.free_row_group(rg_id) + dropped = True + if not dropped and self._on_finalize_row_group is not None: + self._on_finalize_row_group(rg_id) except Exception: logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True) finally: self._rg_semaphore.release() + # Clean up deferred tasks for checkpointed row groups + if completed: + checkpointed = {rg_id for rg_id, _ in completed} + self._deferred = [t for t in self._deferred if t.row_group not in checkpointed] + def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: """Run pre-batch callbacks for row groups whose seeds just completed.""" for rg_id, state in list(self._rg_states.items()): @@ -398,20 +513,56 @@ def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: if self._on_seeds_complete: try: self._on_seeds_complete(rg_id, state.size) + # The callback may drop rows (e.g. pre-batch filtering). + # Record skipped tasks for any newly-dropped rows so + # progress reporting stays accurate. + if self._reporter: + for ri in range(state.size): + if self._tracker.is_dropped(rg_id, ri): + self._record_skipped_tasks_for_row(rg_id, ri) except Exception: logger.warning( f"Pre-batch processor failed for row group {rg_id}, skipping.", exc_info=True, ) - for ri in range(state.size): - self._tracker.drop_row(rg_id, ri) - if self._buffer_manager: - self._buffer_manager.drop_row(rg_id, ri) + self._drop_row_group(rg_id, state.size) + + def _drop_row(self, row_group: int, row_index: int, *, exclude_columns: set[str] | None = None) -> None: + if self._tracker.is_dropped(row_group, row_index): + return - def _in_flight_for_rg(self, rg_id: int) -> bool: - """Check if any tasks are in-flight for a given row group.""" - state = self._rg_states.get(rg_id) - return state is not None and state.in_flight_count > 0 + self._record_skipped_tasks_for_row(row_group, row_index, exclude_columns=exclude_columns) + self._tracker.drop_row(row_group, row_index) + if self._buffer_manager: + self._buffer_manager.drop_row(row_group, row_index) + + def _drop_row_group(self, row_group: int, row_group_size: int, *, exclude_columns: set[str] | None = None) -> None: + for row_index in range(row_group_size): + self._drop_row(row_group, row_index, exclude_columns=exclude_columns) + + def _record_skipped_tasks_for_row( + self, + row_group: int, + row_index: int, + *, + exclude_columns: set[str] | None = None, + ) -> None: + if self._reporter is None: + return + + excluded = exclude_columns or set() + in_flight_columns = { + task.column for task in self._in_flight if task.row_group == row_group and task.row_index == row_index + } + + for column in self._graph.columns: + if column in excluded or self._graph.get_strategy(column) != GenerationStrategy.CELL_BY_CELL: + continue + if column in in_flight_columns: + continue + if self._tracker.is_complete(SliceRef(column=column, row_group=row_group, row_index=row_index)): + continue + self._reporter.record_skipped(column) def _check_error_rate(self, *, success: bool) -> None: """Trigger early shutdown if recent error rate exceeds threshold.""" @@ -428,6 +579,11 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: """Dispatch from_scratch tasks for a row group.""" self._rg_states[rg_id].seeds_dispatched = True seed_cols = self._seed_cols + if not seed_cols: + return + num_rgs = len(self._rg_size_map) + width = 
len(str(num_rgs)) + logger.info(f"🚀 ({rg_id + 1:0{width}d}/{num_rgs}) Dispatching with {rg_size} records") seen_instances: set[int] = set() for col in seed_cols: @@ -484,6 +640,14 @@ async def _execute_task_inner(self, task: Task) -> None: the submission slot (never reacquired). This prevents cross-key starvation while bounding live coroutines. """ + num_rgs = len(self._row_groups) + token = current_row_group.set((task.row_group, num_rgs)) + try: + await self._execute_task_inner_impl(task) + finally: + current_row_group.reset(token) + + async def _execute_task_inner_impl(self, task: Task) -> None: trace: TaskTrace | None = None if self._trace: trace = TaskTrace.from_task(task) @@ -534,32 +698,31 @@ async def _execute_task_inner(self, task: Task) -> None: self._tracker.mark_cell_complete(col, task.row_group, task.row_index) self._check_error_rate(success=True) + if self._reporter: + self._reporter.record_success(task.column) if self._trace and trace: trace.status = "ok" except Exception as exc: if not isinstance(exc, ModelRateLimitError): self._check_error_rate(success=False) + retryable = self._is_retryable(exc) + if not retryable and self._reporter: + self._reporter.record_failure(task.column) if self._trace and trace: trace.status = "error" trace.error = str(exc) - retryable = self._is_retryable(exc) if retryable: self._deferred.append(task) else: # Non-retryable: drop the affected row(s) if task.row_index is not None: - self._tracker.drop_row(task.row_group, task.row_index) - if self._buffer_manager: - self._buffer_manager.drop_row(task.row_group, task.row_index) + self._drop_row(task.row_group, task.row_index, exclude_columns={task.column}) else: # Batch/from_scratch failure: drop all rows in the row group rg_size = self._get_rg_size(task.row_group) - for ri in range(rg_size): - self._tracker.drop_row(task.row_group, ri) - if self._buffer_manager: - self._buffer_manager.drop_row(task.row_group, ri) + self._drop_row_group(task.row_group, rg_size, exclude_columns={task.column}) logger.warning( f"Non-retryable failure on {task.column}[rg={task.row_group}, row={task.row_index}]: {exc}" ) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py similarity index 85% rename from packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py rename to packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index d8e85c0b2..f97054340 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -3,6 +3,7 @@ from __future__ import annotations +import contextlib import functools import logging import os @@ -11,6 +12,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable +import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import CustomColumnConfig from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType from data_designer.config.config_builder import BuilderConfig @@ -35,6 +37,7 @@ from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner, ProcessorStage from data_designer.engine.dataset_builders.utils.progress_tracker import 
ProgressTracker +from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler from data_designer.engine.processing.processors.base import Processor from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor @@ -80,7 +83,7 @@ _CLIENT_VERSION: str = get_library_version() -class ColumnWiseDatasetBuilder: +class DatasetBuilder: def __init__( self, data_designer_config: DataDesignerConfig, @@ -200,17 +203,46 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: self.artifact_storage.set_media_storage_mode(StorageMode.DATAFRAME) generators = self._initialize_generators() - group_id = uuid.uuid4().hex start_time = time.perf_counter() - self.batch_manager.start(num_records=num_records, buffer_size=num_records) - self._run_batch(generators, batch_mode="preview", save_partial_results=False, group_id=group_id) - dataset = self.batch_manager.get_current_batch(as_dataframe=True) - self.batch_manager.reset() + + if DATA_DESIGNER_ASYNC_ENGINE: + self._validate_async_compatibility() + dataset = self._build_async_preview(generators, num_records) + else: + group_id = uuid.uuid4().hex + self.batch_manager.start(num_records=num_records, buffer_size=num_records) + self._run_batch(generators, batch_mode="preview", save_partial_results=False, group_id=group_id) + dataset = self.batch_manager.get_current_batch(as_dataframe=True) + self.batch_manager.reset() self._resource_provider.model_registry.log_model_usage(time.perf_counter() - start_time) return dataset + def _build_async_preview(self, generators: list[ColumnGenerator], num_records: int) -> pd.DataFrame: + """Async preview path - single row group, no disk writes, returns in-memory DataFrame.""" + logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue preview") + + scheduler, buffer_manager = self._prepare_async_run( + generators, + num_records, + buffer_size=num_records, + run_post_batch_in_scheduler=False, + ) + + loop = ensure_async_engine_loop() + future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) + future.result() + + self._task_traces = scheduler.traces + + if not buffer_manager.has_row_group(0): + return lazy.pd.DataFrame() + + dataset = buffer_manager.get_dataframe(0) + buffer_manager.free_row_group(0) + return dataset + def _validate_async_compatibility(self) -> None: """Raise if any column uses allow_resize=True with the async scheduler.""" offending = [config.name for config in self.single_column_configs if getattr(config, "allow_resize", False)] @@ -228,10 +260,69 @@ def _build_async( buffer_size: int, on_batch_complete: Callable[[Path], None] | None = None, ) -> None: - """Async task-queue builder path — dispatches tasks based on dependency readiness.""" + """Async task-queue builder path - dispatches tasks based on dependency readiness.""" logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue builder") - # Build strategy map from generators + settings = self._resource_provider.run_config + trace_enabled = settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1" + + def finalize_row_group(rg_id: int) -> None: + def on_complete(final_path: Path | str | None) -> None: + if final_path is not None and on_batch_complete: + on_batch_complete(final_path) + + buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete) + + scheduler, buffer_manager = self._prepare_async_run( + 
generators, + num_records, + buffer_size, + on_finalize_row_group=finalize_row_group, + shutdown_error_rate=settings.shutdown_error_rate, + shutdown_error_window=settings.shutdown_error_window, + disable_early_shutdown=settings.disable_early_shutdown, + trace=trace_enabled, + ) + + # Telemetry snapshot + group_id = uuid.uuid4().hex + pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot() + + # Run on background event loop + loop = ensure_async_engine_loop() + future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) + future.result() + + self._task_traces = scheduler.traces + + # Emit telemetry + try: + usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot) + self._emit_batch_inference_events("batch", usage_deltas, group_id) + except Exception: + logger.debug("Failed to emit batch telemetry for async run", exc_info=True) + + # Write metadata + buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size) + + def _prepare_async_run( + self, + generators: list[ColumnGenerator], + num_records: int, + buffer_size: int, + *, + on_finalize_row_group: Callable[[int], None] | None = None, + run_post_batch_in_scheduler: bool = True, + shutdown_error_rate: float = 0.5, + shutdown_error_window: int = 10, + disable_early_shutdown: bool = False, + trace: bool = False, + ) -> tuple[AsyncTaskScheduler, RowGroupBufferManager]: + """Build a fully-wired scheduler and buffer manager for async generation. + + Shared setup for both build and preview paths. Processor hooks are always + wired when the config has processors, so callers cannot accidentally omit them. + """ strategies: dict[str, GenerationStrategy] = {} gen_map: dict[str, ColumnGenerator] = {} for gen in generators: @@ -245,7 +336,6 @@ def _build_async( graph = ExecutionGraph.create(self._column_configs, strategies) - # Log pre-generation info for all generators for gen in generators: gen.log_pre_generation() @@ -261,36 +351,25 @@ def _build_async( tracker = CompletionTracker.with_graph(graph, row_groups) buffer_manager = RowGroupBufferManager(self.artifact_storage) - settings = self._resource_provider.run_config # Pre-batch processor callback: runs after seed tasks complete for a row group. # If it raises, the scheduler drops all rows in the row group (skips it). def on_seeds_complete(rg_id: int, rg_size: int) -> None: - if not self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH): - return df = buffer_manager.get_dataframe(rg_id) df = self._processor_runner.run_pre_batch_on_df(df) buffer_manager.replace_dataframe(rg_id, df) - # Sync newly-dropped rows to the tracker so downstream cell tasks are skipped for ri in range(rg_size): if buffer_manager.is_dropped(rg_id, ri) and not tracker.is_dropped(rg_id, ri): tracker.drop_row(rg_id, ri) - # Post-batch processor callback: runs after all columns, before checkpoint. - # rg_id is used as current_batch_number; both are 0-based sequential indices today. + # Post-batch processor callback: runs after all columns, before finalization. 
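+        # If this callback raises, _checkpoint_completed_row_groups logs the
+        # error and drops the entire row group instead of checkpointing it,
+        # mirroring the pre-batch failure behavior above.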
def on_before_checkpoint(rg_id: int, rg_size: int) -> None: df = buffer_manager.get_dataframe(rg_id) df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id) buffer_manager.replace_dataframe(rg_id, df) - # Telemetry snapshot - group_id = uuid.uuid4().hex - pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot() - - trace_enabled = settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1" - # Coarse upper bound: sums all registered aliases, not just those used - # in this build. Oversizing is harmless — ThrottleManager enforces + # in this build. Oversizing is harmless - ThrottleManager enforces # the real per-key limit; the semaphore is a memory-safety cap. aggregate = self._resource_provider.model_registry.get_aggregate_max_parallel_requests() @@ -302,35 +381,25 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: buffer_manager=buffer_manager, max_submitted_tasks=DEFAULT_TASK_POOL_SIZE, max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, LLM_WAIT_POOL_MULTIPLIER * aggregate), - on_checkpoint_complete=on_batch_complete, + on_finalize_row_group=on_finalize_row_group, on_seeds_complete=( on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None ), on_before_checkpoint=( - on_before_checkpoint if self._processor_runner.has_processors_for(ProcessorStage.POST_BATCH) else None + on_before_checkpoint + if run_post_batch_in_scheduler and self._processor_runner.has_processors_for(ProcessorStage.POST_BATCH) + else None ), - shutdown_error_rate=settings.shutdown_error_rate, - shutdown_error_window=settings.shutdown_error_window, - disable_early_shutdown=settings.disable_early_shutdown, - trace=trace_enabled, + shutdown_error_rate=shutdown_error_rate, + shutdown_error_window=shutdown_error_window, + disable_early_shutdown=disable_early_shutdown, + trace=trace, + num_records=num_records, + buffer_size=buffer_size, + progress_interval=self._resource_provider.run_config.progress_interval, + progress_bar=self._resource_provider.run_config.progress_bar, ) - - # Run on background event loop - loop = ensure_async_engine_loop() - future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) - future.result() - - self._task_traces = scheduler.traces - - # Emit telemetry - try: - usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot) - self._emit_batch_inference_events("batch", usage_deltas, group_id) - except Exception: - logger.debug("Failed to emit batch telemetry for async run", exc_info=True) - - # Write metadata - buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size) + return scheduler, buffer_manager def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame: df = self._processor_runner.run_post_batch(dataset.copy(), current_batch_number=None) @@ -474,7 +543,10 @@ def _run_mcp_tool_check_if_needed(self) -> None: self._resource_provider.mcp_registry.run_health_check(tool_aliases) def _setup_fan_out( - self, generator: ColumnGeneratorWithModelRegistry, max_workers: int + self, + generator: ColumnGeneratorWithModelRegistry, + max_workers: int, + progress_bar: StickyProgressBar | None = None, ) -> tuple[ProgressTracker, dict[str, Any]]: if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( @@ -490,9 +562,12 @@ def _setup_fan_out( else: self._cell_resize_mode = False + label = f"{generator.config.column_type} column '{generator.config.name}'" progress_tracker = 
ProgressTracker( total_records=self.batch_manager.num_records_batch, - label=f"{generator.config.column_type} column '{generator.config.name}'", + label=label, + progress_bar=progress_bar, + progress_bar_key=generator.config.name, ) progress_tracker.log_start(max_workers) @@ -539,30 +614,34 @@ def _finalize_fan_out(self, progress_tracker: ProgressTracker) -> None: def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: if getattr(generator.config, "tool_alias", None): logger.info("🛠️ Tool calling enabled") - progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers) - executor = AsyncConcurrentExecutor(max_workers=max_workers, **executor_kwargs) - work_items = [ - ( - generator.agenerate(record), - {"index": i, "column_name": generator.config.name}, - ) - for i, record in self.batch_manager.iter_current_batch() - ] - executor.run(work_items) - self._finalize_fan_out(progress_tracker) + bar = StickyProgressBar() if self._resource_provider.run_config.progress_bar else None + with bar or contextlib.nullcontext(): + progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers, progress_bar=bar) + executor = AsyncConcurrentExecutor(max_workers=max_workers, **executor_kwargs) + work_items = [ + ( + generator.agenerate(record), + {"index": i, "column_name": generator.config.name}, + ) + for i, record in self.batch_manager.iter_current_batch() + ] + executor.run(work_items) + self._finalize_fan_out(progress_tracker) def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: if getattr(generator.config, "tool_alias", None): logger.info("🛠️ Tool calling enabled") - progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers) - with ConcurrentThreadExecutor(max_workers=max_workers, **executor_kwargs) as executor: - for i, record in self.batch_manager.iter_current_batch(): - executor.submit( - lambda record: generator.generate(record), - record, - context={"index": i, "column_name": generator.config.name}, - ) - self._finalize_fan_out(progress_tracker) + bar = StickyProgressBar() if self._resource_provider.run_config.progress_bar else None + with bar or contextlib.nullcontext(): + progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers, progress_bar=bar) + with ConcurrentThreadExecutor(max_workers=max_workers, **executor_kwargs) as executor: + for i, record in self.batch_manager.iter_current_batch(): + executor.submit( + lambda record: generator.generate(record), + record, + context={"index": i, "column_name": generator.config.name}, + ) + self._finalize_fan_out(progress_tracker) def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]: def callback(result: dict, *, context: dict | None = None) -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_progress_reporter.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_progress_reporter.py new file mode 100644 index 000000000..c2075c526 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_progress_reporter.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import time +from typing import TYPE_CHECKING + +from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker +from data_designer.logging import LOG_INDENT + +if TYPE_CHECKING: + from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar + +logger = logging.getLogger(__name__) + +DEFAULT_REPORT_INTERVAL = 5.0 + + +class AsyncProgressReporter: + """Consolidated progress reporter for async generation. + + Owns per-column ProgressTracker instances (in quiet mode) and emits + a single grouped log block at most once per ``report_interval`` seconds. + """ + + def __init__( + self, + trackers: dict[str, ProgressTracker], + *, + report_interval: float = DEFAULT_REPORT_INTERVAL, + progress_bar: StickyProgressBar | None = None, + ) -> None: + self._trackers = trackers + self._report_interval = report_interval + self._start_time = time.perf_counter() + self._last_report_time: float = self._start_time + self._last_reported_total: int = -1 + self._bar = progress_bar + if self._bar is not None: + for col, tracker in trackers.items(): + self._bar.add_bar(col, f"column '{col}'", tracker.total_records) + + def log_start(self, num_row_groups: int) -> None: + cols = ", ".join(self._trackers) + total = sum(t.total_records for t in self._trackers.values()) + logger.info( + "⚡️ Async generation: %d column(s) (%s), %d tasks across %d row group(s)", + len(self._trackers), + cols, + total, + num_row_groups, + ) + + def record_success(self, column: str) -> None: + if tracker := self._trackers.get(column): + tracker.record_success() + self._maybe_report() + + def record_failure(self, column: str) -> None: + if tracker := self._trackers.get(column): + tracker.record_failure() + self._maybe_report() + + def record_skipped(self, column: str) -> None: + if tracker := self._trackers.get(column): + tracker.record_skipped() + self._maybe_report() + + def log_final(self) -> None: + if self._bar is not None and self._bar.is_active: + for col in self._trackers: + self._bar.remove_bar(col) + else: + self._emit() + elapsed = time.perf_counter() - self._start_time + snapshots = [tracker.get_snapshot(elapsed) for tracker in self._trackers.values()] + total_ok = sum(snapshot[2] for snapshot in snapshots) + total_fail = sum(snapshot[3] for snapshot in snapshots) + total_skipped = sum(snapshot[4] for snapshot in snapshots) + skipped_suffix = f", {total_skipped} skipped" if total_skipped else "" + logger.info( + "✅ Async generation complete [%.1fs]: %d ok, %d failed%s across %d column(s)", + elapsed, + total_ok, + total_fail, + skipped_suffix, + len(self._trackers), + ) + + def _maybe_report(self) -> None: + if self._bar is not None and self._bar.is_active: + self._update_bar() + return + now = time.perf_counter() + if now - self._last_report_time < self._report_interval: + return + self._last_report_time = now + self._emit() + + def _update_bar(self) -> None: + elapsed = time.perf_counter() - self._start_time + for col, tracker in self._trackers.items(): + completed, _total, success, failed, _skipped, _pct, _rate, _emoji = tracker.get_snapshot(elapsed) + self._bar.update(col, completed=completed, success=success, failed=failed) + + def _emit(self) -> None: + current_total = sum(tracker.get_snapshot(0.0)[0] for tracker in self._trackers.values()) + if current_total == self._last_reported_total: + return + self._last_reported_total = current_total + + elapsed = 
time.perf_counter() - self._start_time + logger.info("📊 Progress [%.1fs]:", elapsed) + for col, tracker in self._trackers.items(): + completed, total_records, _success, _failed, skipped, pct, rate, emoji = tracker.get_snapshot(elapsed) + skipped_suffix = f", {skipped} skipped" if skipped else "" + logger.info( + "%s%s %s: %d/%d (%.0f%%) %.1f rec/s%s", + LOG_INDENT, + emoji, + col, + completed, + total_records, + pct, + rate, + skipped_suffix, + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/progress_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/progress_tracker.py index 1a07a56b0..73afa2e26 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/progress_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/progress_tracker.py @@ -6,9 +6,13 @@ import logging import time from threading import Lock +from typing import TYPE_CHECKING from data_designer.logging import LOG_INDENT, RandomEmoji +if TYPE_CHECKING: + from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar + logger = logging.getLogger(__name__) @@ -31,21 +35,24 @@ class ProgressTracker: tracker.log_final() """ - def __init__(self, total_records: int, label: str, log_interval_percent: int = 10): - """ - Initialize the progress tracker. - - Args: - total_records: Total number of records to process. - label: Human-readable label for log messages (e.g., "LLM_TEXT column 'response'"). - log_interval_percent: How often to log progress as a percentage (default 10%). - """ + def __init__( + self, + total_records: int, + label: str, + log_interval_percent: int = 10, + *, + quiet: bool = False, + progress_bar: StickyProgressBar | None = None, + progress_bar_key: str | None = None, + ): self.total_records = total_records self.label = label + self.quiet = quiet self.completed = 0 self.success = 0 self.failed = 0 + self.skipped = 0 interval_fraction = max(1, log_interval_percent) / 100.0 self.log_interval = max(1, int(total_records * interval_fraction)) if total_records > 0 else 1 @@ -55,6 +62,11 @@ def __init__(self, total_records: int, label: str, log_interval_percent: int = 1 self.lock = Lock() self._random_emoji = RandomEmoji() + self._bar = progress_bar + self._bar_key = progress_bar_key or label + if self._bar is not None: + self._bar.add_bar(self._bar_key, label, total_records) + def log_start(self, max_workers: int) -> None: """Log the start of processing with worker count and interval information.""" logger.info( @@ -62,6 +74,9 @@ def log_start(self, max_workers: int) -> None: self.label, max_workers, ) + self._log_interval_info() + + def _log_interval_info(self) -> None: interval_str = "after each record" if self.log_interval == 1 else f"every {self.log_interval} records" logger.info( "⏱️ %s will report progress %s", @@ -77,22 +92,34 @@ def record_failure(self) -> None: """Record a failed task completion and log progress if at interval.""" self._record_completion(success=False) + def record_skipped(self) -> None: + """Record a skipped task completion and log progress if at interval.""" + self._record_completion(success=None) + + def get_snapshot(self, elapsed: float | None = None) -> tuple[int, int, int, int, int, float, float, str]: + with self.lock: + return self._get_snapshot_unlocked(elapsed) + def log_final(self) -> None: """Log final progress summary.""" with self.lock: + if self._bar is not None: + 
self._bar.remove_bar(self._bar_key) if self.completed > 0: self._log_progress_unlocked() - def _record_completion(self, *, success: bool) -> None: + def _record_completion(self, *, success: bool | None) -> None: should_log = False with self.lock: self.completed += 1 - if success: + if success is True: self.success += 1 - else: + elif success is False: self.failed += 1 + else: + self.skipped += 1 - if self.completed >= self.next_log_at and self.completed < self.total_records: + if not self.quiet and self.completed >= self.next_log_at and self.completed < self.total_records: should_log = True while self.next_log_at <= self.completed: self.next_log_at += self.log_interval @@ -101,24 +128,40 @@ def _record_completion(self, *, success: bool) -> None: with self.lock: self._log_progress_unlocked() + def _get_snapshot_unlocked(self, elapsed: float | None = None) -> tuple[int, int, int, int, int, float, float, str]: + current_elapsed = time.perf_counter() - self.start_time if elapsed is None else elapsed + rate = self.completed / current_elapsed if current_elapsed > 0 else 0.0 + percent = (self.completed / self.total_records) * 100 if self.total_records else 100.0 + emoji = self._random_emoji.progress(percent) + return self.completed, self.total_records, self.success, self.failed, self.skipped, percent, rate, emoji + def _log_progress_unlocked(self) -> None: """Log current progress. Must be called while holding the lock.""" - elapsed = time.perf_counter() - self.start_time - rate = self.completed / elapsed if elapsed > 0 else 0.0 - remaining = max(0, self.total_records - self.completed) + if self._bar is not None and self._bar.is_active: + self._bar.update( + self._bar_key, + completed=self.completed, + success=self.success, + failed=self.failed, + ) + return + + completed, total_records, success, failed, skipped, percent, rate, emoji = self._get_snapshot_unlocked() + remaining = max(0, total_records - completed) eta = f"{(remaining / rate):.1f}s" if rate > 0 else "unknown" - percent = (self.completed / self.total_records) * 100 if self.total_records else 100.0 + skipped_suffix = f", {skipped} skipped" if skipped else "" logger.info( - "%s%s %s progress: %d/%d (%.0f%%) complete, %d ok, %d failed, %.2f rec/s, eta %s", + "%s%s %s progress: %d/%d (%.0f%%) complete, %d ok, %d failed%s, %.2f rec/s, eta %s", LOG_INDENT, - self._random_emoji.progress(percent), + emoji, self.label, - self.completed, - self.total_records, + completed, + total_records, percent, - self.success, - self.failed, + success, + failed, + skipped_suffix, rate, eta, ) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py index b20ce3aca..3adad1456 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py @@ -62,6 +62,9 @@ def update_batch(self, row_group: int, column: str, values: list[Any]) -> None: def get_row(self, row_group: int, row_index: int) -> dict[str, Any]: return self._buffers[row_group][row_index] + def has_row_group(self, row_group: int) -> bool: + return row_group in self._buffers + def get_dataframe(self, row_group: int) -> pd.DataFrame: """Return the row group as a DataFrame (excluding dropped rows).""" dropped = self._dropped.get(row_group, set()) @@ -91,6 +94,12 @@ def drop_row(self, row_group: int, 
row_index: int) -> None: def is_dropped(self, row_group: int, row_index: int) -> bool: return row_index in self._dropped.get(row_group, set()) + def free_row_group(self, row_group: int) -> None: + """Release buffer memory for a row group without writing to disk.""" + self._buffers.pop(row_group, None) + self._dropped.pop(row_group, None) + self._row_group_sizes.pop(row_group, None) + def checkpoint_row_group( self, row_group: int, @@ -117,10 +126,7 @@ def checkpoint_row_group( if on_complete: on_complete(final_path) - # Free memory - del self._buffers[row_group] - self._dropped.pop(row_group, None) - self._row_group_sizes.pop(row_group, None) + self.free_row_group(row_group) def write_metadata(self, target_num_records: int, buffer_size: int) -> None: """Write final metadata after all row groups are checkpointed.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/sticky_progress_bar.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/sticky_progress_bar.py new file mode 100644 index 000000000..82ca9d72e --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/sticky_progress_bar.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import shutil +import sys +import time +from dataclasses import dataclass, field +from threading import Lock +from typing import TextIO + +BAR_FILLED = "█" +BAR_EMPTY = "░" + + +def _compute_stats_width(total: int) -> int: + """Compute the fixed width of the stats portion based on total records.""" + total_w = len(str(total)) + # " 100% | xxx/xxx | 9999.9 rec/s | eta 999s | xxx failed" + sample = f" 100% | {'9' * total_w}/{total} | 9999.9 rec/s | eta 999s | {'9' * total_w} failed" + return len(sample) + + +@dataclass +class _BarState: + label: str + total: int + completed: int = 0 + success: int = 0 + failed: int = 0 + start_time: float = field(default_factory=time.perf_counter) + stats_width: int = 0 + + def __post_init__(self) -> None: + self.stats_width = _compute_stats_width(self.total) + + +class StickyProgressBar: + """ANSI progress bar that sticks to the bottom of the terminal. + + Log messages (via standard ``logging``) are rendered above the bar + automatically. The bar redraws in-place after each update. + + Usage:: + + with StickyProgressBar() as bar: + bar.add_bar("col_a", "column 'a'", total=100) + for i in range(100): + bar.update("col_a", completed=i + 1, success=i + 1) + bar.remove_bar("col_a") + + Falls back to a no-op on non-TTY streams (CI, pipes, notebooks). 
+ """ + + def __init__(self, stream: TextIO | None = None) -> None: + self._stream = stream or sys.stderr + self._is_tty = hasattr(self._stream, "isatty") and self._stream.isatty() + self._bars: dict[str, _BarState] = {} + self._lock = Lock() + self._drawn_lines = 0 + self._active = False + self._wrapped_handlers: list[tuple[logging.StreamHandler, object]] = [] + + @property + def is_active(self) -> bool: + return self._active + + # -- context manager -- + + def __enter__(self) -> StickyProgressBar: + if self._is_tty: + self._active = True + self._wrap_handlers() + self._write("\033[?25l") # hide cursor + return self + + def __exit__(self, *args: object) -> None: + if self._active: + with self._lock: + self._clear_bars() + self._write("\033[?25h") # show cursor + self._unwrap_handlers() + self._active = False + + # -- public API -- + + def add_bar(self, key: str, label: str, total: int) -> None: + with self._lock: + self._bars[key] = _BarState(label=label, total=total) + if self._active: + self._redraw() + + def update( + self, + key: str, + *, + completed: int, + success: int = 0, + failed: int = 0, + ) -> None: + with self._lock: + if bar := self._bars.get(key): + bar.completed = completed + bar.success = success + bar.failed = failed + if self._active: + self._redraw() + + def remove_bar(self, key: str) -> None: + with self._lock: + self._bars.pop(key, None) + if self._active: + self._redraw() + + # -- handler wrapping -- + + def _wrap_handlers(self) -> None: + """Wrap stderr logging handlers so log lines render above the bars.""" + root = logging.getLogger() + for handler in root.handlers: + if not isinstance(handler, logging.StreamHandler): + continue + if getattr(handler, "stream", None) is not self._stream: + continue + original_emit = handler.emit + + def _make_wrapper(orig: object) -> object: + def wrapped_emit(record: logging.LogRecord) -> None: + with self._lock: + self._clear_bars() + orig(record) # type: ignore[operator] + self._redraw() + + return wrapped_emit + + handler.emit = _make_wrapper(original_emit) # type: ignore[assignment] + self._wrapped_handlers.append((handler, original_emit)) + + def _unwrap_handlers(self) -> None: + for handler, original_emit in self._wrapped_handlers: + handler.emit = original_emit # type: ignore[assignment] + self._wrapped_handlers.clear() + + # -- drawing -- + + def _clear_bars(self) -> None: + """Clear drawn bar lines from the terminal. Caller must hold the lock.""" + if self._drawn_lines > 0: + for _ in range(self._drawn_lines): + self._write("\033[A\033[2K") + self._write("\r\033[2K") + self._drawn_lines = 0 + + def _redraw(self) -> None: + """Redraw all bars. Caller must hold the lock.""" + self._clear_bars() + if not self._bars: + return + width = shutil.get_terminal_size().columns + max_label = max(len(b.label) for b in self._bars.values()) + for bar in self._bars.values(): + line = self._format_bar(bar, width, max_label) + self._write(line + "\n") + self._drawn_lines += 1 + + def _format_bar(self, bar: _BarState, width: int, label_width: int) -> str: + completed = min(bar.completed, bar.total) + pct = (completed / bar.total * 100) if bar.total > 0 else 100.0 + elapsed = time.perf_counter() - bar.start_time + rate = bar.completed / elapsed if elapsed > 0 else 0.0 + remaining = max(0, bar.total - completed) + eta = f"{remaining / rate:.0f}s" if rate > 0 else "?" 
+ + label = bar.label.ljust(label_width) + total_w = len(str(bar.total)) + count_str = f"{completed:>{total_w}}/{bar.total}" + stats = f" {pct:3.0f}% | {count_str} | {rate:6.1f} rec/s | eta {eta:>4s} | {bar.failed:>{total_w}} failed" + stats = stats.ljust(bar.stats_width) + + bar_width = max(10, width - len(label) - bar.stats_width - 4) + filled = int(bar_width * pct / 100) + empty = bar_width - filled + + colored_bar = f"\033[32m{BAR_FILLED * filled}\033[90m{BAR_EMPTY * empty}\033[0m" + return f" {label} {colored_bar}{stats}" + + def _write(self, text: str) -> None: + self._stream.write(text) + self._stream.flush() diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 1302fa905..3ec3f2b42 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -22,7 +22,7 @@ FromScratchColumnGenerator, ) from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler -from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -91,13 +91,13 @@ def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: ) def test_validate_async_compatibility(configs: list[Mock], should_raise: bool) -> None: """Validation rejects allow_resize=True with the async engine.""" - builder = Mock(spec=ColumnWiseDatasetBuilder) + builder = Mock(spec=DatasetBuilder) builder.single_column_configs = configs if should_raise: with pytest.raises(DatasetGenerationError, match="allow_resize=True"): - ColumnWiseDatasetBuilder._validate_async_compatibility(builder) + DatasetBuilder._validate_async_compatibility(builder) else: - ColumnWiseDatasetBuilder._validate_async_compatibility(builder) + DatasetBuilder._validate_async_compatibility(builder) # -- _build_async integration test with mock generators ----------------------- @@ -150,7 +150,11 @@ async def test_build_async_end_to_end() -> None: buffer_manager = RowGroupBufferManager(storage) - checkpointed: list[int] = [] + finalized: list[int] = [] + + def finalize_row_group(rg_id: int) -> None: + buffer_manager.checkpoint_row_group(rg_id) + finalized.append(rg_id) scheduler = AsyncTaskScheduler( generators=gen_map, @@ -158,12 +162,12 @@ async def test_build_async_end_to_end() -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_manager, - on_row_group_complete=lambda rg: checkpointed.append(rg), + on_finalize_row_group=finalize_row_group, ) await scheduler.run() - # Both row groups should be checkpointed - assert sorted(checkpointed) == [0, 1] + # Both row groups should be finalized + assert sorted(finalized) == [0, 1] assert buffer_manager.actual_num_records == 4 # All columns should be complete @@ -177,7 +181,7 @@ async def test_build_async_end_to_end() -> None: def test_sync_path_unaffected_by_async_engine_flag() -> None: """DATA_DESIGNER_ASYNC_ENGINE=0 keeps the sync path unchanged.""" - import data_designer.engine.dataset_builders.column_wise_builder as builder_mod 
+ import data_designer.engine.dataset_builders.dataset_builder as builder_mod assert hasattr(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE") assert isinstance(builder_mod.DATA_DESIGNER_ASYNC_ENGINE, bool) @@ -250,6 +254,7 @@ async def test_checkpoint_produces_correct_parquet_calls() -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_manager, + on_finalize_row_group=lambda rg_id: buffer_manager.checkpoint_row_group(rg_id), ) await scheduler.run() diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index 93af75b17..dd4cbca9b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -268,13 +268,17 @@ async def test_scheduler_with_buffer_manager() -> None: checkpointed: list[int] = [] + def finalize(rg_id: int) -> None: + buffer_mgr.checkpoint_row_group(rg_id) + checkpointed.append(rg_id) + scheduler = AsyncTaskScheduler( generators=generators, graph=graph, tracker=tracker, row_groups=row_groups, buffer_manager=buffer_mgr, - on_row_group_complete=lambda rg: checkpointed.append(rg), + on_finalize_row_group=finalize, ) await scheduler.run() @@ -511,8 +515,17 @@ async def test_scheduler_eager_row_drop_skips_downstream_of_failed_column() -> N "downstream": MockCellGenerator(config=_expr_config("downstream"), resource_provider=provider), } - scheduler, tracker = _build_simple_pipeline( - num_records=2, generators=generators, configs=configs, strategies=strategies, trace=True + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + trace=True, + num_records=2, + buffer_size=2, ) await scheduler.run() @@ -524,6 +537,10 @@ async def test_scheduler_eager_row_drop_skips_downstream_of_failed_column() -> N assert len(downstream_traces) == 0 # Row group is still "complete" (no non-dropped rows remain) assert tracker.is_row_group_complete(0, 2, ["seed", "fail_col", "downstream"]) + assert scheduler._reporter is not None + assert scheduler._reporter._trackers["fail_col"].failed == 2 + assert scheduler._reporter._trackers["downstream"].skipped == 2 + assert scheduler._reporter._trackers["downstream"].completed == 2 @pytest.mark.asyncio(loop_scope="session") @@ -562,7 +579,7 @@ async def test_scheduler_non_retryable_seed_failure_no_keyerror_on_downstream() tracker = CompletionTracker.with_graph(graph, row_groups) buffer_mgr = RowGroupBufferManager(storage) - checkpointed: list[int] = [] + finalized: list[int] = [] scheduler = AsyncTaskScheduler( generators=generators, @@ -570,8 +587,10 @@ async def test_scheduler_non_retryable_seed_failure_no_keyerror_on_downstream() tracker=tracker, row_groups=row_groups, buffer_manager=buffer_mgr, - on_row_group_complete=lambda rg: checkpointed.append(rg), + on_finalize_row_group=lambda rg: finalized.append(rg), trace=True, + num_records=3, + buffer_size=3, ) await scheduler.run() @@ -579,12 +598,58 @@ async def test_scheduler_non_retryable_seed_failure_no_keyerror_on_downstream() for ri in range(3): assert tracker.is_dropped(0, ri) - # Row group still completes (vacuously) and is checkpointed - assert 0 in checkpointed + # Row group is NOT finalized when all rows are dropped (freed instead) + assert 0 
not in finalized # full_out was either never dispatched or silently skipped (no KeyError) full_out_errors = [t for t in scheduler.traces if t.column == "full_out" and t.status == "error"] assert len(full_out_errors) == 0 + assert scheduler._reporter is not None + assert scheduler._reporter._trackers["cell_out"].skipped == 3 + assert scheduler._reporter._trackers["cell_out"].completed == 3 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_pre_batch_failure_marks_downstream_tasks_skipped() -> None: + """Pre-batch row-group drops count downstream cell tasks as skipped.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + def fail_pre_batch(row_group: int, row_group_size: int) -> None: + raise ValueError(f"pre-batch failed for {row_group}/{row_group_size}") + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + on_seeds_complete=fail_pre_batch, + num_records=3, + buffer_size=3, + ) + await scheduler.run() + + for row_index in range(3): + assert tracker.is_dropped(0, row_index) + + assert scheduler._reporter is not None + assert scheduler._reporter._trackers["cell_out"].skipped == 3 + assert scheduler._reporter._trackers["cell_out"].completed == 3 @pytest.mark.asyncio(loop_scope="session") @@ -798,8 +863,8 @@ async def test_scheduler_on_before_checkpoint_callback() -> None: @pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_on_checkpoint_complete_callback_receives_final_path() -> None: - """on_checkpoint_complete is called with the written parquet file path.""" +async def test_scheduler_on_finalize_row_group_callback_fires() -> None: + """on_finalize_row_group is called for each completed row group.""" provider = _mock_provider() configs = [ SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), @@ -818,7 +883,11 @@ async def test_scheduler_on_checkpoint_complete_callback_receives_final_path() - storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" buffer_mgr = RowGroupBufferManager(storage) - callback_log: list[str] = [] + finalized: list[int] = [] + + def finalize_row_group(rg_id: int) -> None: + buffer_mgr.checkpoint_row_group(rg_id) + finalized.append(rg_id) scheduler = AsyncTaskScheduler( generators=generators, @@ -826,16 +895,17 @@ async def test_scheduler_on_checkpoint_complete_callback_receives_final_path() - tracker=tracker, row_groups=row_groups, buffer_manager=buffer_mgr, - on_checkpoint_complete=lambda path: callback_log.append(path), + on_finalize_row_group=finalize_row_group, ) await scheduler.run() - assert callback_log == ["/fake_final.parquet"] + assert finalized == [0] + assert storage.write_batch_to_parquet_file.call_count == 1 @pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_on_checkpoint_complete_skips_empty_row_group() -> None: - """on_checkpoint_complete 
is not called when a row group writes no file.""" +async def test_scheduler_on_finalize_skips_empty_row_group() -> None: + """on_finalize_row_group is not called when all rows are dropped.""" provider = _mock_provider() storage = MagicMock() storage.dataset_name = "test" @@ -861,7 +931,7 @@ async def test_scheduler_on_checkpoint_complete_skips_empty_row_group() -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_mgr, - on_checkpoint_complete=callback, + on_finalize_row_group=callback, ) await scheduler.run() @@ -979,7 +1049,7 @@ async def test_scheduler_out_of_order_row_group_completion() -> None: row_groups=row_groups, buffer_manager=buffer_mgr, max_concurrent_row_groups=2, - on_row_group_complete=lambda rg_id: checkpoint_order.append(rg_id), + on_finalize_row_group=lambda rg_id: checkpoint_order.append(rg_id), ) await scheduler.run() @@ -1319,3 +1389,44 @@ async def agenerate(self, data: dict) -> dict: assert llm_available == max_llm_wait, ( f"LLM-wait semaphore leaked: available={llm_available}, expected={max_llm_wait}" ) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_rg_semaphore_deadlock_with_transient_failures() -> None: + """Row groups stalled by transient failures don't block admission of new row groups. + + Regression test: with max_concurrent_row_groups=1 and 2 row groups, if all + tasks in RG0 fail transiently, the semaphore must still be released so RG1 + can be admitted. The scheduler salvages RG0 inline and continues. + """ + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "col": GenerationStrategy.CELL_BY_CELL, + } + # Fail the first 2 calls (all of RG0), then succeed for everything after. 
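Annotation: as context for the mock setup that follows, here is a minimal, self-contained sketch of the invariant this regression test pins down (names are hypothetical; the real scheduler internals are not shown in this diff). The admission slot must be released before the inline salvage pass, or a stalled row group starves every later one when only one group may run at a time:

    import asyncio

    async def run_row_group(slot: asyncio.Semaphore) -> None:
        async with slot:
            await asyncio.sleep(0)  # stand-in for dispatching the group's tasks
        # The admission slot is already released here: with
        # max_concurrent_row_groups=1, the next row group can be admitted
        # even while this one is salvaged inline after transient failures.
        await asyncio.sleep(0)  # stand-in for the inline salvage/retry pass

    async def main() -> None:
        slot = asyncio.Semaphore(1)
        await asyncio.gather(run_row_group(slot), run_row_group(slot))

    asyncio.run(main())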
+ generators: dict[str, ColumnGenerator] = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "col": MockFailingGenerator(config=_expr_config("col"), resource_provider=provider, transient_failures=2), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 2), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=1, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + assert tracker.is_row_group_complete(0, 2, ["seed", "col"]) + assert tracker.is_row_group_complete(1, 2, ["seed", "col"]) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py similarity index 87% rename from packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py rename to packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 095e548e6..477978a38 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -9,6 +9,7 @@ import pytest +import data_designer.engine.dataset_builders.dataset_builder as builder_mod import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig, SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder @@ -19,7 +20,7 @@ from data_designer.config.seed_source import LocalFileSeedSource from data_designer.config.seed_source_dataframe import DataFrameSeedSource from data_designer.engine.column_generators.generators.base import GenerationStrategy -from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.models.errors import ( FormattedLLMErrorMessage, @@ -84,8 +85,8 @@ def stub_batch_manager(): @pytest.fixture -def stub_column_wise_builder(stub_resource_provider, stub_test_config_builder): - return ColumnWiseDatasetBuilder( +def stub_dataset_builder(stub_resource_provider, stub_test_config_builder): + return DatasetBuilder( data_designer_config=stub_test_config_builder.build(), resource_provider=stub_resource_provider, ) @@ -113,7 +114,7 @@ def builder_with_seed(stub_resource_provider, stub_model_configs, seed_data_setu config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))) config_builder.add_column(SamplerColumnConfig(name="extra", sampler_type="uuid", params=UUIDSamplerParams())) - return ColumnWiseDatasetBuilder( + return DatasetBuilder( data_designer_config=config_builder.build(), resource_provider=stub_resource_provider, ) @@ -130,8 +131,8 @@ def create_mock_processor(name: str, stages: list[str]) -> Mock: return mock_processor -def test_column_wise_dataset_builder_creation(stub_resource_provider, stub_test_config_builder): - builder = ColumnWiseDatasetBuilder( +def test_dataset_builder_creation(stub_resource_provider, stub_test_config_builder): + builder = DatasetBuilder( data_designer_config=stub_test_config_builder.build(), 
resource_provider=stub_resource_provider, ) @@ -140,10 +141,10 @@ def test_column_wise_dataset_builder_creation(stub_resource_provider, stub_test_ assert isinstance(builder._registry, DataDesignerRegistry) -def test_column_wise_dataset_builder_creation_with_custom_registry(stub_resource_provider, stub_test_config_builder): +def test_dataset_builder_creation_with_custom_registry(stub_resource_provider, stub_test_config_builder): custom_registry = Mock(spec=DataDesignerRegistry) - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=stub_test_config_builder.build(), resource_provider=stub_resource_provider, registry=custom_registry, @@ -152,16 +153,16 @@ def test_column_wise_dataset_builder_creation_with_custom_registry(stub_resource assert builder._registry == custom_registry -def test_column_wise_dataset_builder_artifact_storage_property(stub_column_wise_builder, stub_resource_provider): - assert stub_column_wise_builder.artifact_storage == stub_resource_provider.artifact_storage +def test_dataset_builder_artifact_storage_property(stub_dataset_builder, stub_resource_provider): + assert stub_dataset_builder.artifact_storage == stub_resource_provider.artifact_storage -def test_column_wise_dataset_builder_records_to_drop_initialization(stub_column_wise_builder): - assert stub_column_wise_builder._records_to_drop == set() +def test_dataset_builder_records_to_drop_initialization(stub_dataset_builder): + assert stub_dataset_builder._records_to_drop == set() def test_worker_error_callback_logs_schema_validation_detail( - stub_column_wise_builder: ColumnWiseDatasetBuilder, + stub_dataset_builder: DatasetBuilder, caplog: pytest.LogCaptureFixture, ) -> None: exc = ModelGenerationValidationFailureError( @@ -178,17 +179,17 @@ def test_worker_error_callback_logs_schema_validation_detail( ) with caplog.at_level(logging.WARNING): - stub_column_wise_builder._worker_error_callback(exc, context={"index": 248, "column_name": "test_column"}) + stub_dataset_builder._worker_error_callback(exc, context={"index": 248, "column_name": "test_column"}) assert "record at index 248" in caplog.text assert "column 'test_column'" in caplog.text assert "(schema validation)" in caplog.text assert "Response doesn't match requested 'name' is a required property." in caplog.text - assert 248 in stub_column_wise_builder._records_to_drop + assert 248 in stub_dataset_builder._records_to_drop def test_worker_error_callback_logs_timeout_detail( - stub_column_wise_builder: ColumnWiseDatasetBuilder, + stub_dataset_builder: DatasetBuilder, caplog: pytest.LogCaptureFixture, ) -> None: exc = ModelTimeoutError( @@ -199,7 +200,7 @@ def test_worker_error_callback_logs_timeout_detail( ) with caplog.at_level(logging.WARNING): - stub_column_wise_builder._worker_error_callback(exc, context={"index": 17, "column_name": "test_column"}) + stub_dataset_builder._worker_error_callback(exc, context={"index": 17, "column_name": "test_column"}) assert "record at index 17" in caplog.text assert "column 'test_column'" in caplog.text @@ -207,11 +208,11 @@ def test_worker_error_callback_logs_timeout_detail( assert ( "The request to model 'test-model' timed out while running generation for column 'test_column'." 
in caplog.text ) - assert 17 in stub_column_wise_builder._records_to_drop + assert 17 in stub_dataset_builder._records_to_drop def test_worker_error_callback_requires_context_index( - stub_column_wise_builder: ColumnWiseDatasetBuilder, + stub_dataset_builder: DatasetBuilder, caplog: pytest.LogCaptureFixture, ) -> None: exc = ModelTimeoutError( @@ -225,15 +226,15 @@ def test_worker_error_callback_requires_context_index( caplog.at_level(logging.WARNING), pytest.raises(RuntimeError, match="Worker error callback called without a valid context index."), ): - stub_column_wise_builder._worker_error_callback(exc, context=None) + stub_dataset_builder._worker_error_callback(exc, context=None) assert "record at index unknown" in caplog.text - assert len(stub_column_wise_builder._records_to_drop) == 0 + assert len(stub_dataset_builder._records_to_drop) == 0 -def test_column_wise_dataset_builder_batch_manager_initialization(stub_column_wise_builder, stub_resource_provider): - assert stub_column_wise_builder.batch_manager is not None - assert stub_column_wise_builder.batch_manager.artifact_storage == stub_resource_provider.artifact_storage +def test_dataset_builder_batch_manager_initialization(stub_dataset_builder, stub_resource_provider): + assert stub_dataset_builder.batch_manager is not None + assert stub_dataset_builder.batch_manager.artifact_storage == stub_resource_provider.artifact_storage @pytest.mark.parametrize( @@ -246,7 +247,7 @@ def test_column_wise_dataset_builder_batch_manager_initialization(stub_column_wi ), ], ) -def test_column_wise_dataset_builder_single_column_configs_property( +def test_dataset_builder_single_column_configs_property( stub_resource_provider, stub_model_configs, config_type, expected_single_configs ): config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) @@ -256,7 +257,7 @@ def test_column_wise_dataset_builder_single_column_configs_property( single_config = expected_single_configs[0] config_builder.add_column(single_config) - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=config_builder.build(), resource_provider=stub_resource_provider, ) @@ -271,15 +272,15 @@ def test_column_wise_dataset_builder_single_column_configs_property( sampler_config = expected_single_configs[0] config_builder.add_column(sampler_config) - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=config_builder.build(), resource_provider=stub_resource_provider, ) assert builder.single_column_configs == expected_single_configs -def test_column_wise_dataset_builder_build_method_basic_flow( - stub_column_wise_builder, +def test_dataset_builder_build_method_basic_flow( + stub_dataset_builder, stub_batch_manager, stub_resource_provider, ): @@ -298,10 +299,10 @@ def test_column_wise_dataset_builder_build_method_basic_flow( # Mock the batch manager's iter_current_batch method stub_batch_manager.iter_current_batch.return_value = [(0, {"test": "data"})] - stub_column_wise_builder.batch_manager = stub_batch_manager - stub_column_wise_builder.set_processor_runner([]) # No processors for basic flow test + stub_dataset_builder.batch_manager = stub_batch_manager + stub_dataset_builder.set_processor_runner([]) # No processors for basic flow test - result_path = stub_column_wise_builder.build(num_records=100) + result_path = stub_dataset_builder.build(num_records=100) stub_resource_provider.model_registry.run_health_check.assert_called_once() stub_batch_manager.start.assert_called_once_with(num_records=100, buffer_size=50) 
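Annotation: the basic-flow test above fixes the orchestration order of `build()`: the model health check runs once, then the batch manager starts with the requested record count and buffer size. A condensed, runnable sketch of that contract using mocks (the two-collaborator shape is an assumption for illustration, not the builder's real structure):

    from unittest.mock import Mock

    model_registry = Mock()
    batch_manager = Mock()

    def build(num_records: int, buffer_size: int = 50) -> None:
        # Order pinned by the test: health check precedes batch start.
        model_registry.run_health_check()
        batch_manager.start(num_records=num_records, buffer_size=buffer_size)

    build(num_records=100)
    model_registry.run_health_check.assert_called_once()
    batch_manager.start.assert_called_once_with(num_records=100, buffer_size=50)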
@@ -319,7 +320,7 @@ def test_column_wise_dataset_builder_build_method_basic_flow( ), ], ) -def test_column_wise_dataset_builder_validate_column_configs( +def test_dataset_builder_validate_column_configs( stub_model_configs, stub_resource_provider, column_configs, expected_error ): config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) @@ -334,7 +335,7 @@ def test_column_wise_dataset_builder_validate_column_configs( mock_registry.column_generators.get_for_config_type.return_value = mock_generator_class with pytest.raises(DatasetGenerationError, match=expected_error): - ColumnWiseDatasetBuilder( + DatasetBuilder( data_designer_config=config_builder.build(), resource_provider=stub_resource_provider, registry=mock_registry, @@ -342,7 +343,7 @@ def test_column_wise_dataset_builder_validate_column_configs( else: # Empty column_configs case - config_builder will fail at build() due to validation with pytest.raises((DatasetGenerationError, Exception)): - ColumnWiseDatasetBuilder( + DatasetBuilder( config_builder=config_builder, resource_provider=stub_resource_provider, ) @@ -353,7 +354,7 @@ def test_run_config_default_non_inference_max_parallel_workers() -> None: assert run_config.non_inference_max_parallel_workers == 4 -@patch("data_designer.engine.dataset_builders.column_wise_builder.TelemetryHandler") +@patch("data_designer.engine.dataset_builders.dataset_builder.TelemetryHandler") def test_emit_batch_inference_events_emits_from_deltas( mock_telemetry_handler_class: Mock, stub_resource_provider: Mock, @@ -361,7 +362,7 @@ def test_emit_batch_inference_events_emits_from_deltas( ) -> None: usage_deltas = {"test-model": ModelUsageStats(token_usage=TokenUsageStats(input_tokens=50, output_tokens=150))} - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=stub_test_config_builder.build(), resource_provider=stub_resource_provider, ) @@ -390,7 +391,7 @@ def test_emit_batch_inference_events_emits_from_deltas( assert event.output_tokens == 150 -@patch("data_designer.engine.dataset_builders.column_wise_builder.TelemetryHandler") +@patch("data_designer.engine.dataset_builders.dataset_builder.TelemetryHandler") def test_emit_batch_inference_events_skips_when_no_deltas( mock_telemetry_handler_class: Mock, stub_resource_provider: Mock, @@ -398,7 +399,7 @@ def test_emit_batch_inference_events_skips_when_no_deltas( ) -> None: usage_deltas: dict[str, ModelUsageStats] = {} - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=stub_test_config_builder.build(), resource_provider=stub_resource_provider, ) @@ -409,7 +410,7 @@ def test_emit_batch_inference_events_skips_when_no_deltas( mock_telemetry_handler_class.assert_not_called() -@patch("data_designer.engine.dataset_builders.column_wise_builder.TelemetryHandler") +@patch("data_designer.engine.dataset_builders.dataset_builder.TelemetryHandler") def test_emit_batch_inference_events_handles_multiple_models( mock_telemetry_handler_class: Mock, stub_resource_provider: Mock, @@ -420,7 +421,7 @@ def test_emit_batch_inference_events_handles_multiple_models( "model-b": ModelUsageStats(token_usage=TokenUsageStats(input_tokens=50, output_tokens=75)), } - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=stub_test_config_builder.build(), resource_provider=stub_resource_provider, ) @@ -446,7 +447,7 @@ def test_emit_batch_inference_events_handles_multiple_models( (False, 0.5, 0.5, 10), # defaults ], ) 
-@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor") +@patch("data_designer.engine.dataset_builders.dataset_builder.ConcurrentThreadExecutor") def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provider( mock_executor_class: Mock, stub_resource_provider: Mock, @@ -470,7 +471,7 @@ def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provide for processor_config in stub_test_processor_configs: config_builder.add_processor(processor_config) - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=config_builder.build(), resource_provider=stub_resource_provider, ) @@ -497,13 +498,13 @@ def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provide assert call_kwargs["disable_early_shutdown"] == disable_early_shutdown -@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor") +@patch("data_designer.engine.dataset_builders.dataset_builder.ConcurrentThreadExecutor") def test_fan_out_with_threads_passes_column_name_in_context( mock_executor_class: Mock, stub_resource_provider: Mock, stub_test_config_builder: DataDesignerConfigBuilder, ) -> None: - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=stub_test_config_builder.build(), resource_provider=stub_resource_provider, ) @@ -532,13 +533,13 @@ def test_fan_out_with_threads_passes_column_name_in_context( ] -@patch("data_designer.engine.dataset_builders.column_wise_builder.AsyncConcurrentExecutor", create=True) +@patch("data_designer.engine.dataset_builders.dataset_builder.AsyncConcurrentExecutor", create=True) def test_fan_out_with_async_passes_column_name_in_context( mock_executor_class: Mock, stub_resource_provider: Mock, stub_test_config_builder: DataDesignerConfigBuilder, ) -> None: - builder = ColumnWiseDatasetBuilder( + builder = DatasetBuilder( data_designer_config=stub_test_config_builder.build(), resource_provider=stub_resource_provider, ) @@ -585,12 +586,56 @@ def bad_fn(df: pd.DataFrame) -> pd.DataFrame: config = DataDesignerConfigBuilder(model_configs=stub_model_configs) config.add_column(SamplerColumnConfig(name="some_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())) config.add_column(CustomColumnConfig(name="col", generator_function=bad_fn, generation_strategy="full_column")) - builder = ColumnWiseDatasetBuilder(data_designer_config=config.build(), resource_provider=stub_resource_provider) + builder = DatasetBuilder(data_designer_config=config.build(), resource_provider=stub_resource_provider) with pytest.raises(DatasetGenerationError, match=r"(?s)Failed to process column 'col'.*something broke"): builder.build_preview(num_records=3) +def test_build_async_preview_returns_empty_dataframe_when_row_group_is_already_freed( + stub_resource_provider, + stub_test_config_builder, + monkeypatch: pytest.MonkeyPatch, +) -> None: + builder = DatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + class StubScheduler: + traces: list[object] = [] + + async def run(self) -> None: + return None + + class MockFuture: + def result(self) -> None: + return None + + def mock_run_coroutine_threadsafe(coro, loop): + coro.close() + return MockFuture() + + scheduler = StubScheduler() + buffer_manager = Mock() + buffer_manager.has_row_group.return_value = False + + monkeypatch.setattr(builder, "_prepare_async_run", Mock(return_value=(scheduler, buffer_manager))) + 
monkeypatch.setattr(builder_mod, "ensure_async_engine_loop", lambda: object(), raising=False) + monkeypatch.setattr( + builder_mod, + "asyncio", + Mock(run_coroutine_threadsafe=mock_run_coroutine_threadsafe), + raising=False, + ) + + result = builder._build_async_preview([], num_records=3) + + assert result.empty + buffer_manager.get_dataframe.assert_not_called() + buffer_manager.free_row_group.assert_not_called() + + # Processor tests @@ -599,14 +644,14 @@ def simple_builder(stub_resource_provider, stub_model_configs): """Minimal builder with a single UUID column and no batch files on disk.""" config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) config_builder.add_column(SamplerColumnConfig(name="id", sampler_type="uuid", params=UUIDSamplerParams())) - return ColumnWiseDatasetBuilder( + return DatasetBuilder( data_designer_config=config_builder.build(), resource_provider=stub_resource_provider, ) -def test_initialize_processors(stub_column_wise_builder): - processors = stub_column_wise_builder.processors +def test_initialize_processors(stub_dataset_builder): + processors = stub_dataset_builder.processors assert isinstance(processors, tuple) assert len(processors) == 1 assert processors[0].config.column_names == ["column_to_drop"] @@ -804,12 +849,12 @@ def _resize_columns(spec: str) -> list[CustomColumnConfig]: def _build_resize_builder(stub_resource_provider, stub_model_configs, seed_data_setup, columns): - """Build a ColumnWiseDatasetBuilder with the given resize column configs.""" + """Build a DatasetBuilder with the given resize column configs.""" config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) config_builder.with_seed_dataset(LocalFileSeedSource(path=str(seed_data_setup["seed_path"]))) for col in columns: config_builder.add_column(col) - return ColumnWiseDatasetBuilder( + return DatasetBuilder( data_designer_config=config_builder.build(), resource_provider=stub_resource_provider, ) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_progress_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_progress_tracker.py index dcfaec0d1..13b698a22 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_progress_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_progress_tracker.py @@ -22,6 +22,7 @@ def test_initializes_with_correct_values() -> None: assert tracker.completed == 0 assert tracker.success == 0 assert tracker.failed == 0 + assert tracker.skipped == 0 def test_calculates_log_interval_from_percentage() -> None: @@ -80,6 +81,15 @@ def test_record_failure_multiple_times(tracker: ProgressTracker) -> None: assert tracker.failed == 5 +def test_record_skipped_increments_completed_and_skipped(tracker: ProgressTracker) -> None: + tracker.record_skipped() + + assert tracker.completed == 1 + assert tracker.success == 0 + assert tracker.failed == 0 + assert tracker.skipped == 1 + + def test_tracks_mixed_successes_and_failures(tracker: ProgressTracker) -> None: tracker.record_success() tracker.record_success() @@ -92,6 +102,24 @@ def test_tracks_mixed_successes_and_failures(tracker: ProgressTracker) -> None: assert tracker.failed == 2 +def test_get_snapshot_includes_skipped_counts() -> None: + tracker = ProgressTracker(total_records=10, label="test") + tracker.record_success() + tracker.record_failure() + tracker.record_skipped() + + completed, total_records, success, failed, skipped, percent, rate, emoji = 
tracker.get_snapshot(elapsed=2.0) + + assert completed == 3 + assert total_records == 10 + assert success == 1 + assert failed == 1 + assert skipped == 1 + assert percent == 30.0 + assert rate == 1.5 + assert emoji + + def test_log_start_logs_worker_info(tracker: ProgressTracker, caplog: pytest.LogCaptureFixture) -> None: with caplog.at_level(logging.INFO): tracker.log_start(max_workers=8) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py index 189ddcac7..37b6f71ac 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py @@ -26,6 +26,18 @@ def test_init_row_group() -> None: assert row == {} +def test_has_row_group() -> None: + mgr = RowGroupBufferManager(_mock_artifact_storage()) + + assert not mgr.has_row_group(0) + + mgr.init_row_group(0, 1) + assert mgr.has_row_group(0) + + mgr.free_row_group(0) + assert not mgr.has_row_group(0) + + def test_update_cell() -> None: mgr = RowGroupBufferManager(_mock_artifact_storage()) mgr.init_row_group(0, 2) @@ -176,6 +188,28 @@ def test_replace_dataframe_fewer_rows_marks_trailing_dropped() -> None: assert len(result_df) == 2 +def test_free_row_group_releases_memory() -> None: + """free_row_group releases buffer memory without writing to disk.""" + storage = _mock_artifact_storage() + mgr = RowGroupBufferManager(storage) + mgr.init_row_group(0, 3) + mgr.update_batch(0, "col", ["a", "b", "c"]) + mgr.drop_row(0, 1) + + mgr.free_row_group(0) + + with pytest.raises(KeyError): + mgr.get_row(0, 0) + storage.write_batch_to_parquet_file.assert_not_called() + assert mgr.actual_num_records == 0 + + +def test_free_row_group_idempotent() -> None: + """free_row_group on a non-existent row group is a no-op.""" + mgr = RowGroupBufferManager(_mock_artifact_storage()) + mgr.free_row_group(99) # should not raise + + def test_checkpoint_calls_on_complete_when_all_rows_dropped() -> None: storage = _mock_artifact_storage() callback = Mock() diff --git a/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py b/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py index 685c92890..5cba94623 100644 --- a/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py +++ b/packages/data-designer-engine/tests/engine/models/test_async_engine_switch.py @@ -8,7 +8,7 @@ import pytest from data_designer.config.column_configs import GenerationStrategy -from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder from data_designer.engine.models.facade import ModelFacade @@ -28,7 +28,7 @@ def test_model_facade_has_sync_methods() -> None: def test_async_engine_env_controls_builder_execution_path(monkeypatch: pytest.MonkeyPatch) -> None: """When DATA_DESIGNER_ASYNC_ENGINE is set, _run_cell_by_cell_generator dispatches to async fan-out.""" - import data_designer.engine.dataset_builders.column_wise_builder as cwb_module + import data_designer.engine.dataset_builders.dataset_builder as cwb_module mock_generator = MagicMock() mock_generator.get_generation_strategy.return_value = GenerationStrategy.CELL_BY_CELL @@ -39,7 +39,7 @@ def test_async_engine_env_controls_builder_execution_path(monkeypatch: pytest.Mo # Test with async 
enabled — uses max_parallel_requests from generator (same as sync) with patch.object(cwb_module, "DATA_DESIGNER_ASYNC_ENGINE", True): - ColumnWiseDatasetBuilder._run_cell_by_cell_generator(builder, mock_generator) + DatasetBuilder._run_cell_by_cell_generator(builder, mock_generator) builder._fan_out_with_async.assert_called_once_with(mock_generator, max_workers=4) builder._fan_out_with_threads.assert_not_called() @@ -47,6 +47,6 @@ def test_async_engine_env_controls_builder_execution_path(monkeypatch: pytest.Mo # Test with async disabled — uses max_parallel_requests from generator with patch.object(cwb_module, "DATA_DESIGNER_ASYNC_ENGINE", False): - ColumnWiseDatasetBuilder._run_cell_by_cell_generator(builder, mock_generator) + DatasetBuilder._run_cell_by_cell_generator(builder, mock_generator) builder._fan_out_with_threads.assert_called_once_with(mock_generator, max_workers=4) builder._fan_out_with_async.assert_not_called() diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index c195a174d..e487074a5 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -35,7 +35,7 @@ from data_designer.config.utils.info import InfoType, InterfaceInfo from data_designer.engine.analysis.dataset_profiler import DataDesignerDatasetProfiler, DatasetProfilerConfig from data_designer.engine.compiler import compile_data_designer_config -from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder from data_designer.engine.mcp.io import list_tool_names from data_designer.engine.model_provider import resolve_model_provider_registry from data_designer.engine.resources.person_reader import ( @@ -430,8 +430,8 @@ def _create_dataset_builder( self, data_designer_config: DataDesignerConfig, resource_provider: ResourceProvider, - ) -> ColumnWiseDatasetBuilder: - return ColumnWiseDatasetBuilder( + ) -> DatasetBuilder: + return DatasetBuilder( data_designer_config=data_designer_config, resource_provider=resource_provider, ) diff --git a/packages/data-designer/tests/interface/test_data_designer.py b/packages/data-designer/tests/interface/test_data_designer.py index c7c5190d4..62c7ed3eb 100644 --- a/packages/data-designer/tests/interface/test_data_designer.py +++ b/packages/data-designer/tests/interface/test_data_designer.py @@ -562,7 +562,7 @@ def test_preview_raises_generation_error_when_dataset_is_empty( ) with patch( - "data_designer.engine.dataset_builders.column_wise_builder.ColumnWiseDatasetBuilder.process_preview", + "data_designer.engine.dataset_builders.dataset_builder.DatasetBuilder.process_preview", return_value=lazy.pd.DataFrame(), ): with pytest.raises(DataDesignerGenerationError, match="Dataset is empty"): diff --git a/scripts/benchmarks/benchmark_engine_v2.py b/scripts/benchmarks/benchmark_engine_v2.py index 286fe48f4..2b4d19523 100644 --- a/scripts/benchmarks/benchmark_engine_v2.py +++ b/scripts/benchmarks/benchmark_engine_v2.py @@ -591,7 +591,7 @@ def _dataset_fingerprint(df: pd.DataFrame) -> str: def _run_single_benchmark(settings: BenchmarkSettings, engine_mode: str) -> BenchmarkResult: # Imports are deferred so engine selection respects DATA_DESIGNER_ASYNC_ENGINE. 
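Annotation: the deferred-import comment above matters because `DATA_DESIGNER_ASYNC_ENGINE` is a module-level bool (see `test_sync_path_unaffected_by_async_engine_flag`), so it is presumably snapshotted from the environment when the builder module is first imported. A hedged sketch of the ordering that implies for callers, assuming the flag is read from the environment at import time:

    import os

    # Assumption: set the flag before the builder module is imported for
    # the first time; later changes to the environment have no effect.
    os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1"

    from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder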
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage - from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder + from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder from data_designer.engine.model_provider import resolve_model_provider_registry from data_designer.engine.resources.resource_provider import create_resource_provider from data_designer.engine.resources.seed_reader import SeedReaderRegistry @@ -636,7 +636,7 @@ def _run_single_benchmark(settings: BenchmarkSettings, engine_mode: str) -> Benc mcp_providers=[mcp_provider], tool_configs=builder.tool_configs, ) - dataset_builder = ColumnWiseDatasetBuilder( + dataset_builder = DatasetBuilder( data_designer_config=builder.build(), resource_provider=resource_provider, )
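Closing migration note: downstream code pinned to the old module path can bridge the rename with a try/except shim (hypothetical; this change does not ship one) so both versions stay importable during the transition:

    # Hypothetical compatibility shim; not part of this change.
    try:
        from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder
    except ImportError:  # versions that predate the rename
        from data_designer.engine.dataset_builders.column_wise_builder import (
            ColumnWiseDatasetBuilder as DatasetBuilder,
        )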