diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 932c7fae7..4bb497a9d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -4,15 +4,20 @@ from __future__ import annotations import asyncio +import concurrent.futures import functools import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, Coroutine, TypeVar, overload from data_designer.config.column_configs import GenerationStrategy from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT from data_designer.logging import LOG_DOUBLE_INDENT, LOG_INDENT +_T = TypeVar("_T") + +_SYNC_BRIDGE_TIMEOUT = 300 + if TYPE_CHECKING: import pandas as pd @@ -23,33 +28,84 @@ logger = logging.getLogger(__name__) +def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: + """Run an async coroutine from sync context. + + - No running event loop → ``asyncio.run(coro)`` + - Running event loop (e.g. 
notebook/service) → run in a background thread + """ + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(asyncio.run, coro) + timed_out = False + try: + result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT) + except concurrent.futures.TimeoutError as exc: + timed_out = True + logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running") + raise TimeoutError(f"_run_coroutine_sync timed out after {_SYNC_BRIDGE_TIMEOUT}s") from exc + finally: + pool.shutdown(wait=not timed_out, cancel_futures=timed_out) + return result + + class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): @property def can_generate_from_scratch(self) -> bool: return False + @property + def is_order_dependent(self) -> bool: + """Whether this generator's output depends on prior row-group calls. + + Example: SeedDatasetColumnGenerator tracks its position in the seed + dataset, so row group N must complete before N+1 starts. + """ + return False + + def _is_overridden(self, method_name: str) -> bool: + """Check if a subclass has overridden a base ColumnGenerator method.""" + return getattr(type(self), method_name) is not getattr(ColumnGenerator, method_name) + @staticmethod @abstractmethod def get_generation_strategy() -> GenerationStrategy: ... @overload - @abstractmethod def generate(self, data: dict) -> dict: ... @overload - @abstractmethod def generate(self, data: pd.DataFrame) -> pd.DataFrame: ... - @abstractmethod - def generate(self, data: DataT) -> DataT: ... + def generate(self, data: DataT) -> DataT: + """Sync generate — overridden by most concrete generators. + + Default bridges to ``agenerate()`` for async-first subclasses that only + implement ``agenerate()``. Raises ``NotImplementedError`` if neither + ``generate()`` nor ``agenerate()`` is overridden. 
+ """ + if not self._is_overridden("agenerate"): + raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()") + return _run_coroutine_sync(self.agenerate(data)) - async def agenerate(self, data: dict) -> dict: - """Async fallback — delegates to sync generate via thread pool. + @overload + async def agenerate(self, data: dict) -> dict: ... + + @overload + async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: ... + + async def agenerate(self, data: DataT) -> DataT: + """Async generate — delegates to sync ``generate()`` via thread pool. Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion) should override this with a direct async implementation. """ - return await asyncio.to_thread(self.generate, data) + if not self._is_overridden("generate"): + raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()") + return await asyncio.to_thread(self.generate, data.copy()) def log_pre_generation(self) -> None: """A shared method to log info before the generator's `generate` method is called. @@ -68,6 +124,10 @@ def can_generate_from_scratch(self) -> bool: @abstractmethod def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ... 
+ async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame: + """Async wrapper — wraps sync ``generate_from_scratch()`` in a thread.""" + return await asyncio.to_thread(self.generate_from_scratch, num_records) + class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC): @property diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index c8ba53ab5..4874f77b8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import inspect import logging from typing import TYPE_CHECKING, Any @@ -65,12 +66,57 @@ def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict return self._generate(data, is_dataframe) + async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict]: + """Async generate — branches on strategy and detects coroutine functions.""" + is_full_column = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN + if is_full_column: + return await asyncio.to_thread(self.generate, data.copy()) + # The @custom_column_generator decorator wraps the user function in a sync + # wrapper, so we must unwrap to detect async functions. 
+ fn_unwrapped = inspect.unwrap(self.config.generator_function) + if asyncio.iscoroutinefunction(fn_unwrapped): + missing = set(self.config.required_columns) - set(data.keys()) + if missing: + raise CustomColumnGenerationError( + f"Missing required columns for custom generator '{self.config.name}': {sorted(missing)}" + ) + keys_before = set(data.keys()) + + try: + result = await self._ainvoke_generator_function(data) + except CustomColumnGenerationError: + raise + except Exception as e: + logger.warning( + f"⚠️ Custom generator function {self.config.generator_function.__name__!r} " + f"failed for column '{self.config.name}'. This record will be skipped.\n{e}" + ) + raise CustomColumnGenerationError( + f"Custom generator function failed for column '{self.config.name}': {e}" + ) from e + + return self._postprocess_result(result, is_dataframe=False, keys_before=keys_before) + return await asyncio.to_thread(self.generate, data) + + async def _ainvoke_generator_function(self, data: dict) -> dict | pd.DataFrame: + """Invoke an async user generator function with appropriate arguments. + + The @custom_column_generator decorator's sync wrapper returns a coroutine + when the original function is async, so we await the wrapper's return value. 
+ """ + params = self._get_validated_params() + fn = self.config.generator_function + if len(params) == 1: + return await fn(data) + elif len(params) == 2: + return await fn(data, self.config.generator_params) + else: + models = self._build_models_dict() + return await fn(data, self.config.generator_params, models) + def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame | list[dict]: """Unified generation logic for both strategies.""" - # Get columns/keys using unified accessor get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys())) - expected_type = lazy.pd.DataFrame if is_dataframe else dict - type_name = "DataFrame" if is_dataframe else "dict" # Check required columns missing = set(self.config.required_columns) - get_keys(data) @@ -96,6 +142,15 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. f"Custom generator function failed for column '{self.config.name}': {e}" ) from e + return self._postprocess_result(result, is_dataframe, keys_before) + + def _postprocess_result( + self, + result: dict | pd.DataFrame | list[dict], + is_dataframe: bool, + keys_before: set[str], + ) -> dict | pd.DataFrame | list[dict]: + """Validate type and output columns of a generation result.""" # Cell-by-cell with allow_resize: accept dict or list[dict] if not is_dataframe and self.config.allow_resize: if isinstance(result, dict): @@ -113,6 +168,8 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. 
) # Validate return type for non-resize paths + expected_type = lazy.pd.DataFrame if is_dataframe else dict + type_name = "DataFrame" if is_dataframe else "dict" if not isinstance(result, expected_type): raise CustomColumnGenerationError( f"Custom generator for column '{self.config.name}' must return a {type_name}, " diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py index 83b13ffd9..82eaf795b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py @@ -27,9 +27,19 @@ class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]): def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL - def generate(self, data: dict) -> dict: + def _prepare_embedding_inputs(self, data: dict) -> list[str]: deserialized_record = deserialize_json_values(data) - input_texts = parse_list_string(deserialized_record[self.config.target_column]) + return parse_list_string(deserialized_record[self.config.target_column]) + + def generate(self, data: dict) -> dict: + input_texts = self._prepare_embedding_inputs(data) embeddings = self.model.generate_text_embeddings(input_texts=input_texts) data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json") return data + + async def agenerate(self, data: dict) -> dict: + """Native async generate using model.agenerate_text_embeddings.""" + input_texts = self._prepare_embedding_inputs(data) + embeddings = await self.model.agenerate_text_embeddings(input_texts=input_texts) + data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json") + return data diff --git 
a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 31095c490..2060afb64 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING from data_designer.config.column_configs import ImageColumnConfig @@ -31,46 +32,42 @@ def media_storage(self) -> MediaStorage: def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL - def generate(self, data: dict) -> dict: - """Generate image(s) and optionally save to disk. - - Args: - data: Record data - - Returns: - Record with image path(s) (create mode) or base64 data (preview mode) added - """ + def _prepare_image_inputs(self, data: dict) -> tuple[str, list[dict] | None]: + """Validate inputs and render prompt for image generation.""" deserialized_record = deserialize_json_values(data) - - # Validate required columns missing_columns = list(set(self.config.required_columns) - set(data.keys())) if len(missing_columns) > 0: - error_msg = ( + raise ValueError( f"There was an error preparing the Jinja2 expression template. " f"The following columns {missing_columns} are missing!" 
) - raise ValueError(error_msg) - - # Render prompt template self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys())) prompt = self.render_template(deserialized_record) - - # Validate prompt is non-empty if not prompt or not prompt.strip(): raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty") - - # Process multi-modal context if provided multi_modal_context = self._build_multi_modal_context(deserialized_record) + return prompt, multi_modal_context - # Generate images (returns list of base64 strings) + def generate(self, data: dict) -> dict: + """Generate image(s) and optionally save to disk.""" + prompt, multi_modal_context = self._prepare_image_inputs(data) base64_images = self.model.generate_image(prompt=prompt, multi_modal_context=multi_modal_context) - - # Store via media storage (mode determines disk vs dataframe storage) - # Use column name as subfolder to organize images results = [ self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name) for base64_image in base64_images ] data[self.config.name] = results + return data + async def agenerate(self, data: dict) -> dict: + """Native async generate using model.agenerate_image.""" + prompt, multi_modal_context = self._prepare_image_inputs(data) + base64_images = await self.model.agenerate_image(prompt=prompt, multi_modal_context=multi_modal_context) + results = await asyncio.to_thread( + lambda: [ + self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name) + for base64_image in base64_images + ] + ) + data[self.config.name] = results return data diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py index b71876fdf..a310ca3c8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ 
b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -29,6 +29,10 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.FULL_COLUMN + @property + def is_order_dependent(self) -> bool: + return True + @property def num_records_sampled(self) -> int: return self._num_records_sampled diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py new file mode 100644 index 000000000..7eff29fec --- /dev/null +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py @@ -0,0 +1,424 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +import data_designer.lazy_heavy_imports as lazy +from data_designer.config.column_configs import ( + CustomColumnConfig, + EmbeddingColumnConfig, + ExpressionColumnConfig, + GenerationStrategy, + ImageColumnConfig, +) +from data_designer.config.custom_column import custom_column_generator +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + ColumnGeneratorFullColumn, + FromScratchColumnGenerator, + _run_coroutine_sync, +) +from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator +from data_designer.engine.column_generators.generators.embedding import ( + EmbeddingCellGenerator, + EmbeddingGenerationResult, +) +from data_designer.engine.column_generators.generators.image import ImageCellGenerator +from data_designer.engine.column_generators.generators.llm_completion import ( + ColumnGeneratorWithModelChatCompletion, +) +from 
data_designer.engine.column_generators.generators.seed_dataset import SeedDatasetColumnGenerator +from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError +from data_designer.engine.resources.resource_provider import ResourceProvider + +# -- Helpers ----------------------------------------------------------------- + + +def _mock_provider() -> Mock: + return Mock(spec=ResourceProvider) + + +def _make_expr_config(name: str = "test") -> ExpressionColumnConfig: + return ExpressionColumnConfig(name=name, expr="{{ col1 }}", dtype="str") + + +# -- _run_coroutine_sync tests ----------------------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_coroutine_sync_with_running_loop() -> None: + """When called inside a running event loop, runs coroutine in a new thread.""" + + async def add(a: int, b: int) -> int: + return a + b + + result = _run_coroutine_sync(add(1, 2)) + assert result == 3 + + +def test_run_coroutine_sync_from_sync_context() -> None: + """When called from sync context (no loop), uses asyncio.run.""" + + async def double(x: int) -> int: + return x * 2 + + result = _run_coroutine_sync(double(5)) + assert result == 10 + + +# -- is_order_dependent default ---------------------------------------------------- + + +def test_is_order_dependent_default_false() -> None: + class SyncGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + return data + + gen = SyncGen(config=_make_expr_config(), resource_provider=_mock_provider()) + assert gen.is_order_dependent is False + + +# -- Symmetric bridging: sync-only generator called via agenerate ----------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sync_only_generator_agenerate() -> None: + """Sync-only generator can be called via agenerate().""" + + class 
SyncOnlyGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data["result"] = "sync" + return data + + gen = SyncOnlyGen(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate({"col1": "x"}) + assert result["result"] == "sync" + + +# -- Symmetric bridging: async-only generator called via generate ----------- + + +def test_async_only_generator_generate() -> None: + """Async-only generator can be called via generate() from sync context.""" + + class AsyncOnlyGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + async def agenerate(self, data: dict) -> dict: + data["result"] = "async" + return data + + gen = AsyncOnlyGen(config=_make_expr_config(), resource_provider=_mock_provider()) + result = gen.generate({"col1": "x"}) + assert result["result"] == "async" + + +# -- Neither overridden raises NotImplementedError -------------------------- + + +def test_neither_generate_nor_agenerate_raises() -> None: + """If neither generate() nor agenerate() is overridden, generate() raises.""" + + class BareGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + gen = BareGen(config=_make_expr_config(), resource_provider=_mock_provider()) + with pytest.raises(NotImplementedError, match="must implement either"): + gen.generate({"col1": "x"}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_neither_generate_nor_agenerate_raises_from_async() -> None: + """If neither is overridden, agenerate() raises directly without thread bounce.""" + + class BareGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return 
GenerationStrategy.CELL_BY_CELL + + gen = BareGen(config=_make_expr_config(), resource_provider=_mock_provider()) + with pytest.raises(NotImplementedError, match="must implement either"): + await gen.agenerate({"col1": "x"}) + + +# -- FromScratchColumnGenerator async wrappers -------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_from_scratch_agenerate_from_scratch() -> None: + """FromScratchColumnGenerator.agenerate_from_scratch wraps sync correctly.""" + + class TestFromScratch(FromScratchColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame({"val": list(range(num_records))}) + + gen = TestFromScratch(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate_from_scratch(3) + assert len(result) == 3 + assert list(result["val"]) == [0, 1, 2] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_from_scratch_agenerate_passes_copy() -> None: + """FromScratchColumnGenerator.agenerate passes df.copy() to thread.""" + original = lazy.pd.DataFrame({"col1": [1, 2, 3]}) + received_data: list[lazy.pd.DataFrame] = [] + + class TestFromScratch(FromScratchColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + received_data.append(data) + data["new_col"] = "added" + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame() + + gen = TestFromScratch(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate(original) + + # Original should not be 
mutated + assert "new_col" not in original.columns + assert "new_col" in result.columns + + +# -- ColumnGeneratorFullColumn async wrapper ---------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_full_column_agenerate_passes_copy() -> None: + """ColumnGeneratorFullColumn.agenerate passes df.copy() to thread.""" + original = lazy.pd.DataFrame({"col1": ["a", "b"]}) + + class TestFullCol(ColumnGeneratorFullColumn[ExpressionColumnConfig]): + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + data["added"] = True + return data + + gen = TestFullCol(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate(original) + + assert "added" not in original.columns + assert "added" in result.columns + + +# -- SeedDatasetColumnGenerator is_order_dependent ----------------------------------- + + +def test_seed_dataset_is_order_dependent() -> None: + gen = object.__new__(SeedDatasetColumnGenerator) + assert gen.is_order_dependent is True + + +# -- CustomColumnGenerator agenerate branching -------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_sync_function() -> None: + """Sync custom function is wrapped in asyncio.to_thread via agenerate.""" + + @custom_column_generator() + def sync_fn(row: dict) -> dict: + row["sync_col"] = "hello" + return row + + config = CustomColumnConfig(name="sync_col", generator_function=sync_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + result = await gen.agenerate({"input": "val"}) + assert result["sync_col"] == "hello" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_function() -> None: + """Async custom function is called directly as coroutine.""" + + @custom_column_generator() + async def async_fn(row: dict) -> dict: + row["async_col"] = "async_hello" + return row + + config = CustomColumnConfig(name="async_col", 
generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + result = await gen.agenerate({"input": "val"}) + assert result["async_col"] == "async_hello" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_full_column_wraps_in_thread() -> None: + """Full-column custom generator wraps in asyncio.to_thread with df.copy().""" + + @custom_column_generator() + def full_col_fn(df: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + df["fc_col"] = "batch" + return df + + config = CustomColumnConfig( + name="fc_col", + generator_function=full_col_fn, + generation_strategy=GenerationStrategy.FULL_COLUMN, + ) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + + original = lazy.pd.DataFrame({"input": [1, 2]}) + result = await gen.agenerate(original) + + # Should not mutate the original since we pass .copy() in agenerate + assert "fc_col" not in original.columns + assert "fc_col" in result.columns + + +# -- Existing generators still work unchanged ---------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_llm_completion_agenerate_still_works() -> None: + """Verify LLM completion generators still have working agenerate (from PR #280).""" + assert hasattr(ColumnGeneratorWithModelChatCompletion, "agenerate") + # The agenerate is a custom implementation, not the base default + assert ColumnGeneratorWithModelChatCompletion.agenerate is not ColumnGenerator.agenerate + + +# -- Async custom generator error path parity --------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_missing_required_columns() -> None: + """Async custom generator raises on missing required_columns.""" + + @custom_column_generator(required_columns=["input"]) + async def async_fn(row: dict) -> dict: + row["result"] = row["input"].upper() + return row + + config = CustomColumnConfig(name="result", 
generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="Missing required columns"): + await gen.agenerate({"other": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_missing_output_column() -> None: + """Async custom generator raises when expected output column is missing.""" + + @custom_column_generator() + async def async_fn(row: dict) -> dict: + row["wrong"] = "value" + return row + + config = CustomColumnConfig(name="expected", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="did not create the expected column"): + await gen.agenerate({"input": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_missing_side_effect_column() -> None: + """Async custom generator raises when declared side_effect column is missing.""" + + @custom_column_generator(side_effect_columns=["secondary"]) + async def async_fn(row: dict) -> dict: + row["primary"] = 1 + return row + + config = CustomColumnConfig(name="primary", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="did not create declared side_effect_columns"): + await gen.agenerate({"input": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_allow_resize_invalid_list() -> None: + """Async custom generator with allow_resize rejects invalid non-dict list items.""" + + @custom_column_generator(required_columns=["x"]) + async def async_fn(row: dict) -> list: + return [1, 2] + + config = CustomColumnConfig( + name="out", + generator_function=async_fn, + allow_resize=True, + ) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with 
pytest.raises(CustomColumnGenerationError, match="list elements must be dicts"): + await gen.agenerate({"x": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_wraps_exception() -> None: + """Async custom generator wraps user exceptions in CustomColumnGenerationError.""" + + @custom_column_generator() + async def async_fn(row: dict) -> dict: + raise ValueError("async boom") + + config = CustomColumnConfig(name="result", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="Custom generator function failed"): + await gen.agenerate({"input": 1}) + + +# -- ImageCellGenerator async ------------------------------------------------ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_image_agenerate(stub_resource_provider: Mock) -> None: + """ImageCellGenerator.agenerate calls model.agenerate_image.""" + mock_storage = Mock() + mock_storage.save_base64_image.side_effect = ["images/img1.png", "images/img2.png"] + stub_resource_provider.artifact_storage.media_storage = mock_storage + + config = ImageColumnConfig(name="test_image", prompt="A {{ style }} image", model_alias="test_model") + gen = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + + with patch.object(gen, "model") as mock_model: + mock_model.agenerate_image = AsyncMock(return_value=["b64_1", "b64_2"]) + result = await gen.agenerate({"style": "photorealistic"}) + + assert result["test_image"] == ["images/img1.png", "images/img2.png"] + mock_model.agenerate_image.assert_awaited_once() + + +# -- EmbeddingCellGenerator async -------------------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_embedding_agenerate(stub_resource_provider: Mock) -> None: + """EmbeddingCellGenerator.agenerate calls model.agenerate_text_embeddings.""" + config = EmbeddingColumnConfig(name="test_emb", 
target_column="text", model_alias="test_model") + gen = EmbeddingCellGenerator(config=config, resource_provider=stub_resource_provider) + + stub_embeddings = [[0.1, 0.2], [0.3, 0.4]] + with patch.object(gen, "model") as mock_model: + mock_model.agenerate_text_embeddings = AsyncMock(return_value=stub_embeddings) + result = await gen.agenerate({"text": "['hello', 'world']"}) + + expected = EmbeddingGenerationResult(embeddings=stub_embeddings).model_dump(mode="json") + assert result["test_emb"] == expected + mock_model.agenerate_text_embeddings.assert_awaited_once_with(input_texts=["hello", "world"]) diff --git a/plans/346/async-generators-and-task-queue.md b/plans/346/async-generators-and-task-queue.md index 0ac4e7ba7..73ec28a80 100644 --- a/plans/346/async-generators-and-task-queue.md +++ b/plans/346/async-generators-and-task-queue.md @@ -228,17 +228,25 @@ graph. - [x] `ExecutionGraph` class: - Backing stores: `dict[str, set[str]]` column → upstream columns; `dict[str, GenerationStrategy]` column → generation strategy - - `upstream(column: str) -> set[str]` — direct dependencies of a column - - `downstream(column: str) -> set[str]` — columns that depend on this one (for error attribution) - - `strategy(column: str) -> GenerationStrategy` — cell-by-cell or full-column - - `topological_order() -> list[str]` — valid DAG execution order (used by scheduler and for validation) - - `critical_path() -> list[str]` — longest dependency chain (useful for ETA estimates) - - `task_count(num_records: int, buffer_size: int) -> dict[str, int]` — exact task count per + - `get_upstream_columns(column: str) -> set[str]` — direct dependencies of a column + - `get_downstream_columns(column: str) -> set[str]` — columns that depend on this one (for error attribution) + - `get_strategy(column: str) -> GenerationStrategy` — cell-by-cell or full-column + - `get_topological_order() -> list[str]` — valid DAG execution order (cached; used by scheduler and for validation) + - 
`get_longest_dependency_chain() -> list[str]` — longest dependency chain by column count (useful for ETA estimates) + - `get_root_columns() -> list[str]` — columns with no upstream deps, in topological order + - `split_upstream_by_strategy(column: str) -> tuple[list[str], list[str]]` — splits + upstream into (batch/full-column, cell-by-cell) groups; cached per column + - `compute_task_count(num_records: int, buffer_size: int) -> dict[str, int]` — exact task count per column before the run starts; cell-by-cell columns produce `num_records` tasks, full-column columns (including from-scratch generators, which report `FULL_COLUMN`) produce `ceil(num_records / buffer_size)` tasks + - `compute_cell_dependencies(column, row_group, row_index | None, row_group_size) -> list[SliceRef]` + — derives cell-level deps on demand from column-level DAG + strategy - `to_mermaid() -> str` — Mermaid diagram string; nodes are annotated with strategy type -- [x] `build_execution_graph(column_configs, strategies: dict[str, GenerationStrategy]) -> ExecutionGraph` utility: + - `columns` property — all column names in insertion order + - `add_column(name, strategy)` / `add_edge(upstream, downstream)` — low-level construction + - `set_side_effect(side_effect_col, producer)` / `resolve_side_effect(column) -> str` — side-effect mapping +- [x] `ExecutionGraph.create(column_configs, strategies)` classmethod factory: - Input: the ordered list of `ColumnConfigT` / `MultiColumnConfig`, plus a pre-computed strategy map (available from generators at builder init time via `get_generation_strategy()`) - For each config, read `config.required_columns` → set of upstream column names @@ -246,26 +254,42 @@ graph. 
and map them back to their producer column, so downstream references resolve correctly
   - For `MultiColumnConfig`, all sub-columns share the same dependencies
   - Validate: every required column must resolve to a known producer (including
-  registered side-effect outputs), and the graph must be acyclic
-- [x] Unit tests for graph construction, validation, critical path, task count, and Mermaid output
+  registered side-effect outputs), and the graph must be acyclic (raises `DAGCircularDependencyError`)
+- [x] Unit tests for graph construction, validation, longest chain, task count, cell deps, and Mermaid output

**Files**: new module `engine/dataset_builders/utils/execution_graph.py`, tests

### Step 2: Completion Tracker

-A lightweight structure tracking which (column, row_group, row_index) tuples are
-done. Row indices are **local** to their row group (0-based within each group),
-matching the buffer manager's per-row-group addressing.
+A frontier-based tracker that records which (column, row_group, row_index) tuples are
+done and maintains a ready-to-dispatch frontier. Row indices are **local** to their
+row group (0-based within each group), matching the buffer manager's per-row-group addressing.
- [x] `CompletionTracker` class: - Internal: `dict[int, dict[str, set[int]]]` mapping row_group → column → set of completed local row indices - - `mark_complete(column: str, row_group: int, row_index: int)` / `mark_batch_complete(column: str, row_group: int, row_group_size: int)` - - `is_ready(column: str, row_group: int, row_index: int, graph: ExecutionGraph) -> bool` — checks all upstream columns for that (row_group, row_index) - - `is_batch_ready(column: str, row_group: int, row_group_size: int, graph: ExecutionGraph) -> bool` — checks all rows in group - - `drop_row(row_group: int, row_index: int)` — marks row as dropped across all columns; - `get_ready_tasks` skips dropped rows, in-flight tasks for dropped rows are ignored on completion - - `is_row_group_complete(row_group: int, row_group_size: int, all_columns: list[str]) -> bool` — all non-dropped rows have all columns done; `row_group_size` is the original size, dropped rows (via `drop_row`) are excluded internally - - `get_ready_tasks(graph: ExecutionGraph, row_groups, dispatched: set[Task]) -> list[Task]` — yields all currently dispatchable tasks, excluding dropped rows and already-dispatched/in-flight tasks; reads `graph.strategy(column)` to determine task granularity per column + - `with_graph(graph: ExecutionGraph, row_groups: list[tuple[int, int]]) -> CompletionTracker` — + classmethod factory that creates a frontier-enabled tracker; seeds the frontier with root tasks + - `mark_cell_complete(column, row_group, row_index)` — marks a cell done, discards it + from the frontier, and calls `_enqueue_downstream` to add newly-ready tasks + - `mark_row_range_complete(column, row_group, row_group_size)` — marks an entire batch done, + validates row-group size consistency, and enqueues downstream + - `is_complete(ref: SliceRef) -> bool` — check if a single cell is complete + - `is_all_complete(cells: list[SliceRef]) -> bool` — check if all given cells/batches are complete + - `drop_row(row_group, row_index)` — 
marks row as dropped; removes cell tasks for that row + from the frontier; calls `_reevaluate_batch_tasks` since dropping a row may unblock + full-column downstream tasks + - `is_dropped(row_group, row_index) -> bool` + - `is_row_group_complete(row_group, row_group_size, all_columns) -> bool` — all non-dropped rows have all columns done + - `get_ready_tasks(dispatched: set[Task]) -> list[Task]` — returns all currently dispatchable + tasks from the frontier, excluding already-dispatched/in-flight tasks; O(frontier) not O(C × R) + - Internal frontier management: + - `_seed_frontier()` — populates frontier with root column tasks (from `graph.get_root_columns()`) + - `_enqueue_downstream(column, row_group, row_index | None)` — on completion, checks each + downstream column's readiness using `split_upstream_by_strategy`; adds ready tasks to frontier + - `_reevaluate_batch_tasks(row_group)` — after row drop, checks if any full-column tasks + became ready (all non-dropped rows now complete) + - Strategy validation: `mark_cell_complete` requires `CELL_BY_CELL`, `mark_row_range_complete` + requires `FULL_COLUMN`; mismatches raise `ValueError` - [x] No locks needed: all access is from the single asyncio event loop thread - [x] Unit tests @@ -273,61 +297,79 @@ matching the buffer manager's per-row-group addressing. ### Step 3: Task Model -Simple dataclass representing a unit of work. +Simple dataclasses representing units of work and cell-level references. 
-- [x] `Task` dataclass: +- [x] `SliceRef` dataclass (frozen, ordered): + - `column: str`, `row_group: int`, `row_index: int | None = None` + - Reference to a cell or full row group in the execution grid + - Used by `ExecutionGraph.compute_cell_dependencies()` and `CompletionTracker.is_complete()` +- [x] `Task` dataclass (frozen): - `column: str` - `row_group: int` - `row_index: int | None` (None for batch tasks) - `task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"]` -- [x] `TaskResult` with status, output, error info +- [x] `TaskResult` dataclass: + - `task: Task`, `status: Literal["success", "error"]`, `output: Any`, `error: Exception | None` + - `retryable: bool = False` — whether the failure can be retried by the salvage loop - [x] `TaskTrace` dataclass (only instantiated when tracing is enabled): - `column: str`, `row_group: int`, `row_index: int | None`, `task_type: str` - `dispatched_at: float` — `perf_counter()` when `create_task()` fires - `slot_acquired_at: float` — after execution semaphore acquired - `completed_at: float` — in `finally` block after generator returns - `status: str`, `error: str | None` + - `from_task(task: Task) -> TaskTrace` classmethod factory - [x] Hashable so we can track dispatched/pending sets +- [x] `DAGCircularDependencyError` in `errors.py` — raised by `ExecutionGraph.get_topological_order()` -**Files**: new module `engine/dataset_builders/utils/task_model.py` — must be its own module -since `CompletionTracker`, `AsyncTaskScheduler`, and the buffer manager all reference `Task`/`TaskResult`; -inlining would create import cycles. +**Files**: new module `engine/dataset_builders/utils/task_model.py` (+ `errors.py`) — must be its own +module since `CompletionTracker`, `AsyncTaskScheduler`, and the buffer manager all reference +`Task`/`TaskResult`/`SliceRef`; inlining would create import cycles. 
### Step 4: Async Task Scheduler The core orchestrator that replaces `_run_batch` for the async path. - [ ] `AsyncTaskScheduler` class: - - Constructor takes: generators (by column name), `graph: ExecutionGraph`, completion tracker, row group definitions, concurrency limit (`async_scheduler_max_submitted_tasks`), row group semaphore (`async_max_concurrent_row_groups`), salvage config, error/result callbacks, `trace: bool = False` - - When `trace=True`, populates `scheduler.traces: list[TaskTrace]` (one record per task); otherwise no `TaskTrace` objects are created. See Profiling. + - Constructor takes: generators (by column name), `graph: ExecutionGraph`, row group + definitions (`list[tuple[int, int]]`), concurrency limit (`async_scheduler_max_submitted_tasks`), + row group semaphore (`async_max_concurrent_row_groups`), salvage config, error/result + callbacks, `trace: bool = False` + - Initializes `CompletionTracker.with_graph(graph, row_groups)` — the tracker seeds + its frontier with root tasks automatically + - When `trace=True`, populates `scheduler.traces: list[TaskTrace]` (one record per task, + created via `TaskTrace.from_task()`); otherwise no `TaskTrace` objects are created. See Profiling. - `async run()` — main loop: 1. Acquire the row group semaphore (`async_max_concurrent_row_groups`) before admitting a new row group's seed tasks. Dispatch `from_scratch` tasks, respecting `is_stateful`: stateful generators serialize per-instance (row group N's seed completes before N+1's seed starts for that generator); stateless generators dispatch all admitted row groups concurrently - 2. Loop: query `completion_tracker.get_ready_tasks()` → dispatch each via - `asyncio.create_task()` behind submission budget → on completion, update - tracker → repeat until all tasks done or early shutdown + 2. 
Loop: pull from `tracker.get_ready_tasks(dispatched)` → dispatch each via + `asyncio.create_task()` behind submission budget → on completion, call + `tracker.mark_cell_complete()` or `tracker.mark_row_range_complete()` (the tracker's + internal `_enqueue_downstream` auto-populates the frontier with newly-ready tasks) + → repeat until all tasks done or early shutdown 3. When ready queue drains, run salvage rounds over deferred retryable failures - (up to `async_salvage_max_rounds` rounds) - 4. After each row group completes: run post-batch processors, checkpoint + (up to `async_salvage_max_rounds` rounds); check `TaskResult.retryable` to classify + 4. After each row group completes (check via `tracker.is_row_group_complete()`): + run post-batch processors, checkpoint - Task dispatch follows the pattern from §4: acquire execution slot → prepare → release → await throttle (LLM only) → reacquire → execute + writeback → release - Admission control: never allow more than `async_scheduler_max_submitted_tasks` - tasks in submitted/running/waiting states; hold remaining ready tasks in the - scheduler queue until slots free up - - Error handling: classify failures as retryable vs non-retryable; retryable - go to deferred queue with backoff; same early-shutdown logic as - `AsyncConcurrentExecutor` (error rate threshold within sliding window) + tasks in submitted/running/waiting states; remove tasks from `dispatched` set on + completion; hold remaining ready tasks in the scheduler queue until slots free up + - Error handling: classify failures as retryable vs non-retryable (set `TaskResult.retryable`); + retryable go to deferred queue with backoff; non-retryable trigger `tracker.drop_row()` + which auto-removes cell tasks from frontier and re-evaluates batch readiness; + same early-shutdown logic as `AsyncConcurrentExecutor` (error rate threshold within sliding window) - Progress tracking: create one `ProgressTracker` per column for accounting (success/failure counts, rate, 
ETA), but suppress per-completion interval logs in async mode. A separate background coroutine (`asyncio.create_task`) emits a single consolidated summary line every 10 seconds across all active columns; it is cancelled once all tasks complete. See UX Considerations. - [ ] Use `asyncio.Event` to wake the scheduler when a task completes (avoids polling). - `Event` is sufficient — the scheduler resets it and re-checks ready tasks on each wake; - `Condition` would be needed only if waiting on a specific predicate, which the tracker + `Event` is sufficient — the scheduler resets it and re-checks `get_ready_tasks` on each wake; + `Condition` would be needed only if waiting on a specific predicate, which the frontier already handles. - [ ] Unit tests with mock generators @@ -352,38 +394,40 @@ This means sync-first generators (most built-ins, existing plugins) work unchang and async-first generators (new plugins doing native async I/O) only need to implement `agenerate()` without writing a redundant sync version. -- [ ] Add symmetric bridging on the base `ColumnGenerator`: +- [x] Add symmetric bridging on the base `ColumnGenerator`: - `agenerate()` default: `asyncio.to_thread(self.generate, data)` (already exists) - `generate()` default: call a safe sync runner helper that: - uses `asyncio.run()` if no loop is running in the current thread - otherwise submits to the background loop with `run_coroutine_threadsafe(...).result(timeout=...)` - Detect which one the subclass overrides to avoid infinite recursion -- [ ] Add `is_stateful` property to base `ColumnGenerator` (default `False`). + - **Note**: v1 uses ThreadPoolExecutor fallback instead of builder's background loop (available in PR 4) +- [x] Add `is_stateful` property to base `ColumnGenerator` (default `False`). Stateful generators are serialized per-instance by the scheduler. 
-- [ ] `ColumnGeneratorWithModelChatCompletion.agenerate` — already implemented (PR #280), no changes needed -- [ ] `FromScratchColumnGenerator`: add both async wrappers — `async agenerate_from_scratch(num_records) -> DataFrame` +- [x] `ColumnGeneratorWithModelChatCompletion.agenerate` — already implemented (PR #280), no changes needed +- [x] `FromScratchColumnGenerator`: add both async wrappers — `async agenerate_from_scratch(num_records) -> DataFrame` (wraps `generate_from_scratch` in `asyncio.to_thread`) and `async agenerate(data: DataFrame) -> DataFrame` (wraps `generate` in `asyncio.to_thread` with defensive `df.copy()`). Both are needed because the scheduler dispatches subclasses via either path depending on whether the buffer is empty. -- [ ] `ColumnGeneratorFullColumn`: add `async agenerate(data: DataFrame) -> DataFrame` — wraps sync in +- [x] `ColumnGeneratorFullColumn`: add `async agenerate(data: DataFrame) -> DataFrame` — wraps sync in `asyncio.to_thread` with defensive `df.copy()` (see Risks). This intentionally overrides the base `ColumnGenerator.agenerate(dict)` with a DataFrame-typed signature; the scheduler dispatches the correct variant based on generation strategy. -- [ ] `ExpressionColumnGenerator`: inherits full-column async wrapper -- [ ] `SamplerColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. `is_stateful = False` -- [ ] `SeedDatasetColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. `is_stateful = True` (maintains DuckDB batch reader cursor and leftover-row buffer) -- [ ] `ValidationColumnGenerator`: inherits full-column async wrapper. Note: for `REMOTE` validators +- [x] `ExpressionColumnGenerator`: inherits full-column async wrapper +- [x] `SamplerColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. 
`is_stateful = False` +- [x] `SeedDatasetColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. `is_stateful = True` (maintains DuckDB batch reader cursor and leftover-row buffer) +- [x] `ValidationColumnGenerator`: inherits full-column async wrapper. Note: for `REMOTE` validators with `max_parallel_requests > 1`, `generate()` internally uses `ConcurrentThreadExecutor`, so the async wrapper spawns a thread that itself spawns more threads — bypassing the scheduler's concurrency controls for those HTTP calls. Acceptable for v1 (see Follow-ups). -- [ ] `CustomColumnGenerator`: inherits directly from `ColumnGenerator` (not from +- [x] `CustomColumnGenerator`: inherits directly from `ColumnGenerator` (not from `ColumnGeneratorFullColumn`), so it does not automatically inherit the full-column async wrapper. Needs its own `agenerate` that branches on strategy: - `CELL_BY_CELL`: if the user function is a coroutine (`asyncio.iscoroutinefunction`), call it directly; otherwise wrap in `asyncio.to_thread` - `FULL_COLUMN`: wrap `generate(DataFrame)` in `asyncio.to_thread` with defensive `df.copy()` `is_stateful` defaults to `False`; custom implementations can override it. -- [ ] `ImageCellGenerator`, `EmbeddingCellGenerator`: add native `agenerate` using `model.agenerate_image` / `model.agenerate_text_embeddings` + - **Note**: uses `inspect.unwrap()` to detect async through the `@custom_column_generator` decorator wrapper +- [x] `ImageCellGenerator`, `EmbeddingCellGenerator`: add native `agenerate` using `model.agenerate_image` / `model.agenerate_text_embeddings` **Files**: `generators/base.py`, `generators/expression.py`, `generators/samplers.py`, `generators/seed_dataset.py`, `generators/image.py`, `generators/embedding.py`, tests @@ -410,9 +454,12 @@ Adapt `DatasetBatchManager` for concurrent row group processing. Wire the new scheduler into `ColumnWiseDatasetBuilder`. 
- [ ] New method `_build_async(generators, num_records, buffer_size, ...)`: - 1. Build `ExecutionGraph` from `self._column_configs` and generator strategies - 2. Partition rows into row groups - 3. Create `CompletionTracker`, `AsyncTaskScheduler` + 1. Build `ExecutionGraph.create(self._column_configs, strategies)` from configs and + generator strategies; catch `DAGCircularDependencyError` and `ValueError` and + re-raise as `DatasetGenerationError` with context + 2. Partition rows into row groups as `list[tuple[int, int]]` (rg_id, rg_size) + 3. Create `AsyncTaskScheduler` (which internally creates + `CompletionTracker.with_graph(graph, row_groups)`) 4. Run scheduler on the background event loop (reuse `_ensure_async_engine_loop()` from `dataset_builders/utils/async_concurrency.py` — already exists) 5. Scheduler handles checkpointing via callbacks @@ -431,33 +478,35 @@ Wire the new scheduler into `ColumnWiseDatasetBuilder`. Tests are added incrementally with each PR, not deferred to the end. 
-**PR 1 (foundation) — unit tests**: -- [x] Execution graph construction, validation, topological order, critical path +**PR 1 (foundation) — unit tests** (merged): +- [x] Execution graph construction, validation, `get_topological_order`, `get_longest_dependency_chain` - [x] Execution graph: side-effect output columns resolve correctly (e.g., column depending on `summary__trace` maps to a dependency on the `summary` generator) -- [x] Execution graph: `cell_dependencies` returns correct deps for cell-by-cell, +- [x] Execution graph: `compute_cell_dependencies` returns correct deps for cell-by-cell, full-column, and from-scratch columns -- [x] Execution graph: `task_count` and `to_mermaid` output -- [x] Completion tracker: `mark_complete`, `is_complete`, `all_complete` -- [x] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete` -- [x] Task model: hashability, equality, TaskResult, TaskTrace +- [x] Execution graph: `compute_task_count`, `split_upstream_by_strategy`, and `to_mermaid` output +- [x] Completion tracker: `mark_cell_complete`, `mark_row_range_complete`, `is_complete`, `is_all_complete` +- [x] Completion tracker: frontier-based `get_ready_tasks` with `with_graph` initialization +- [x] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete`, `_reevaluate_batch_tasks` +- [x] Task model: hashability, equality, TaskResult (including `retryable`), TaskTrace, SliceRef **PR 2 (generators) — unit tests**: -- [ ] Symmetric bridging: sync-only generator can be called via `agenerate` -- [ ] Symmetric bridging: async-only generator can be called via `generate` -- [ ] `is_stateful` defaults to `False`; `SeedDatasetColumnGenerator` returns `True` -- [ ] `FromScratchColumnGenerator.agenerate_from_scratch` wraps sync correctly -- [ ] `ColumnGeneratorFullColumn.agenerate` passes `df.copy()` to thread -- [ ] `CustomColumnGenerator.agenerate` detects coroutine functions and calls directly -- [ ] All existing generator tests pass unchanged (`make 
test`) +- [x] Symmetric bridging: sync-only generator can be called via `agenerate` +- [x] Symmetric bridging: async-only generator can be called via `generate` +- [x] `is_stateful` defaults to `False`; `SeedDatasetColumnGenerator` returns `True` +- [x] `FromScratchColumnGenerator.agenerate_from_scratch` wraps sync correctly +- [x] `ColumnGeneratorFullColumn.agenerate` passes `df.copy()` to thread +- [x] `CustomColumnGenerator.agenerate` detects coroutine functions and calls directly +- [x] All existing generator tests pass unchanged (`make test`) **PR 3 (scheduler + buffer) — unit tests with mock generators**: -- [ ] Scheduler dispatches from-scratch tasks first, then downstream as deps complete +- [ ] Scheduler dispatches root tasks first (from `tracker.get_ready_tasks`), + then downstream as deps complete (via tracker's `_enqueue_downstream`) - [ ] Stateful generator serializes across row groups; stateless runs concurrently -- [ ] Retry salvage: transient failure is retried and succeeds; - non-retryable failure drops immediately; retry budget exhaustion drops correctly -- [ ] Eager row-drop: failure on column B drops the row across all columns, - independent column C does not process the dropped row +- [ ] Retry salvage: transient failure (`TaskResult.retryable=True`) is retried and succeeds; + non-retryable failure triggers `tracker.drop_row()` immediately; retry budget exhaustion drops correctly +- [ ] Eager row-drop: failure on column B calls `tracker.drop_row()` which removes + cell tasks for that row from frontier; independent column C does not process the dropped row - [ ] Row-drop with in-flight full-column task: completed task may still compute dropped rows, but writeback is suppressed and row remains dropped - [ ] Bounded submission: submitted task count never exceeds @@ -489,9 +538,10 @@ The implementation steps map to 4 PRs that can be reviewed and merged independen Each PR is self-contained: it adds new modules with full test coverage but does not 
change existing behavior until the final integration PR. -### PR 1: Foundation (Steps 1 + 2 + 3) +### PR 1: Foundation (Steps 1 + 2 + 3) — MERGED as [#356](https://github.com/NVIDIA-NeMo/DataDesigner/pull/356) -**Scope**: `ExecutionGraph`, `CompletionTracker`, `Task`/`TaskResult`/`TaskTrace` dataclasses. +**Scope**: `ExecutionGraph`, `CompletionTracker`, `SliceRef`/`Task`/`TaskResult`/`TaskTrace` +dataclasses, `DAGCircularDependencyError`. All three are pure data structures with no side effects on the existing codebase. They live in new modules under `engine/dataset_builders/utils/` and are only imported @@ -500,16 +550,18 @@ by code introduced in later PRs. - `execution_graph.py` + tests - `completion_tracker.py` + tests - `task_model.py` + tests +- `errors.py` (`DAGCircularDependencyError`) **Why grouped**: the three are tightly coupled (the tracker takes the graph to resolve readiness, the task model is the unit of work for both), small individually, and have no external dependencies. Splitting them into 3 separate PRs would create review overhead without meaningful isolation benefit. -**What works after merge**: you can build an `ExecutionGraph` from any existing config, -inspect it (`topological_order`, `critical_path`, `task_count`, `to_mermaid`), query -cell-level dependencies, and track completion state — all in isolation, with full test -coverage. No runtime behavior changes. +**What works after merge**: you can build an `ExecutionGraph.create()` from any existing config, +inspect it (`get_topological_order`, `get_longest_dependency_chain`, `compute_task_count`, +`to_mermaid`), query cell-level dependencies via `compute_cell_dependencies()`, and track +completion state with the frontier-enabled `CompletionTracker.with_graph()` — all in +isolation, with full test coverage. No runtime behavior changes. **Can merge independently**: yes — no existing code imports these modules. @@ -540,13 +592,21 @@ Existing sync callers are unaffected. 
**Scope**: `AsyncTaskScheduler`, row group buffer manager. -- `async_scheduler.py` + tests (uses graph, tracker, and task model from PR 1) +- `async_scheduler.py` + tests (uses `ExecutionGraph.create()`, + `CompletionTracker.with_graph()`, `Task`, `TaskResult`, `TaskTrace`, `SliceRef` from PR 1) - Buffer manager extension in `dataset_batch_manager.py` + tests - Retry/salvage logic, progress consolidation, error handling -**Depends on**: PR 1 (imports `ExecutionGraph`, `CompletionTracker`, `Task`), PR 2 +**Depends on**: PR 1 (imports `ExecutionGraph`, `CompletionTracker`, `Task`, `SliceRef`), PR 2 (calls `agenerate` / `agenerate_from_scratch`, reads `is_stateful`). +**Key integration with PR 1's frontier model**: The scheduler initializes +`CompletionTracker.with_graph(graph, row_groups)` which auto-seeds the frontier with +root tasks. The main loop pulls from `tracker.get_ready_tasks(dispatched)`, and on task +completion calls `mark_cell_complete()` / `mark_row_range_complete()` which internally +enqueues newly-ready downstream tasks. On row drop, calls `tracker.drop_row()` which +removes frontier tasks and re-evaluates batch readiness. + **What works after merge**: the scheduler can be instantiated with mock generators and driven through its full lifecycle in tests — row group admission, dependency-driven dispatch, retry/salvage, row drops, checkpoint callbacks. The buffer manager supports @@ -829,11 +889,13 @@ mid-run loses at most one batch. **`ExecutionTraits` replaced by `GenerationStrategy` on the graph.** PR #269 attaches an `ExecutionTraits` flag enum (`CELL`, `BARRIER`, `ROW_STREAMABLE`) to each node. Since our graph is column-level, we store `GenerationStrategy` (cell-by-cell, full-column) directly -on each column node instead. From-scratch columns are identified by having no upstream -dependencies in the graph; the scheduler checks `can_generate_from_scratch` on the generator -instance to determine which method to call. 
This serves the same purpose as `ExecutionTraits` -— the scheduler and `CompletionTracker` use it to determine task granularity — without -needing typed node IDs or flag combinations. +on each column node instead (accessible via `get_strategy()`). From-scratch columns are +identified by having no upstream dependencies in the graph (via `get_root_columns()`); the +scheduler checks `can_generate_from_scratch` on the generator instance to determine which +method to call. The `split_upstream_by_strategy()` method provides cached separation of +upstream deps by strategy type, used by the tracker's frontier logic. This serves the same +purpose as `ExecutionTraits` — the scheduler and `CompletionTracker` use it to determine +task granularity — without needing typed node IDs or flag combinations. **`ROW_STREAMABLE` trait omitted.** PR #269 introduces `is_row_streamable` so full-column generators that process rows independently (e.g., `ExpressionColumnGenerator`) can be