From 99750d60c1ef64b6a8975f8ac031920cee61edbc Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 6 Mar 2026 18:03:07 -0300 Subject: [PATCH 1/8] feat: add async generator migration with symmetric bridging and statefulness - Symmetric generate/agenerate bridging in base ColumnGenerator - is_stateful property; SeedDatasetColumnGenerator declares True - Async wrappers for FromScratchColumnGenerator and ColumnGeneratorFullColumn - Native async paths for ImageCellGenerator and EmbeddingCellGenerator - CustomColumnGenerator.agenerate with full validation parity - Extract _postprocess_result for shared sync/async output validation --- .../column_generators/generators/base.py | 59 ++- .../column_generators/generators/custom.py | 63 ++- .../column_generators/generators/embedding.py | 8 + .../column_generators/generators/image.py | 33 ++ .../generators/seed_dataset.py | 4 + .../generators/test_async_generators.py | 364 ++++++++++++++++++ plans/346/async-generators-and-task-queue.md | 232 +++++++---- 7 files changed, 669 insertions(+), 94 deletions(-) create mode 100644 packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 932c7fae7..5ce859432 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -4,15 +4,20 @@ from __future__ import annotations import asyncio +import concurrent.futures import functools import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, Coroutine, TypeVar, overload from data_designer.config.column_configs import GenerationStrategy from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT from data_designer.logging import LOG_DOUBLE_INDENT, LOG_INDENT +_T = TypeVar("_T") + +_SYNC_BRIDGE_TIMEOUT = 300 + if TYPE_CHECKING: import pandas as pd @@ -23,28 +28,58 @@ logger = logging.getLogger(__name__) +def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: + """Run an async coroutine from sync context. + + - No running event loop → ``asyncio.run(coro)`` + - Running event loop (e.g. notebook/service) → run in a background thread + """ + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro) + return future.result(timeout=_SYNC_BRIDGE_TIMEOUT) + + class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): @property def can_generate_from_scratch(self) -> bool: return False + @property + def is_stateful(self) -> bool: + """Whether this generator maintains state across calls. + + Stateful generators are serialized per-instance by the async scheduler + (row group N must complete before N+1 starts for that generator). + """ + return False + @staticmethod @abstractmethod def get_generation_strategy() -> GenerationStrategy: ... @overload - @abstractmethod def generate(self, data: dict) -> dict: ... @overload - @abstractmethod def generate(self, data: pd.DataFrame) -> pd.DataFrame: ... - @abstractmethod - def generate(self, data: DataT) -> DataT: ... + def generate(self, data: DataT) -> DataT: + """Sync generate — overridden by most concrete generators. + + Default bridges to ``agenerate()`` for async-first subclasses that only + implement ``agenerate()``. Raises ``NotImplementedError`` if neither + ``generate()`` nor ``agenerate()`` is overridden. + """ + if type(self).agenerate is ColumnGenerator.agenerate: + raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()") + return _run_coroutine_sync(self.agenerate(data)) async def agenerate(self, data: dict) -> dict: - """Async fallback — delegates to sync generate via thread pool. + """Async generate — delegates to sync ``generate()`` via thread pool. Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion) should override this with a direct async implementation. @@ -68,6 +103,14 @@ def can_generate_from_scratch(self) -> bool: @abstractmethod def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ... + async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame: + """Async wrapper — wraps sync ``generate_from_scratch()`` in a thread.""" + return await asyncio.to_thread(self.generate_from_scratch, num_records) + + async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: + """Async wrapper — wraps sync ``generate()`` in a thread with defensive copy.""" + return await asyncio.to_thread(self.generate, data.copy()) + class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC): @property @@ -155,3 +198,7 @@ def get_generation_strategy() -> GenerationStrategy: @abstractmethod def generate(self, data: pd.DataFrame) -> pd.DataFrame: ... + + async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: + """Async wrapper — wraps sync ``generate()`` in a thread with defensive copy.""" + return await asyncio.to_thread(self.generate, data.copy()) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index c8ba53ab5..4874f77b8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import inspect import logging from typing import TYPE_CHECKING, Any @@ -65,12 +66,57 @@ def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict return self._generate(data, is_dataframe) + async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict]: + """Async generate — branches on strategy and detects coroutine functions.""" + is_full_column = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN + if is_full_column: + return await asyncio.to_thread(self.generate, data.copy()) + # The @custom_column_generator decorator wraps the user function in a sync + # wrapper, so we must unwrap to detect async functions. + fn_unwrapped = inspect.unwrap(self.config.generator_function) + if asyncio.iscoroutinefunction(fn_unwrapped): + missing = set(self.config.required_columns) - set(data.keys()) + if missing: + raise CustomColumnGenerationError( + f"Missing required columns for custom generator '{self.config.name}': {sorted(missing)}" + ) + keys_before = set(data.keys()) + + try: + result = await self._ainvoke_generator_function(data) + except CustomColumnGenerationError: + raise + except Exception as e: + logger.warning( + f"⚠️ Custom generator function {self.config.generator_function.__name__!r} " + f"failed for column '{self.config.name}'. This record will be skipped.\n{e}" + ) + raise CustomColumnGenerationError( + f"Custom generator function failed for column '{self.config.name}': {e}" + ) from e + + return self._postprocess_result(result, is_dataframe=False, keys_before=keys_before) + return await asyncio.to_thread(self.generate, data) + + async def _ainvoke_generator_function(self, data: dict) -> dict | pd.DataFrame: + """Invoke an async user generator function with appropriate arguments. + + The @custom_column_generator decorator's sync wrapper returns a coroutine + when the original function is async, so we await the wrapper's return value. + """ + params = self._get_validated_params() + fn = self.config.generator_function + if len(params) == 1: + return await fn(data) + elif len(params) == 2: + return await fn(data, self.config.generator_params) + else: + models = self._build_models_dict() + return await fn(data, self.config.generator_params, models) + def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame | list[dict]: """Unified generation logic for both strategies.""" - # Get columns/keys using unified accessor get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys())) - expected_type = lazy.pd.DataFrame if is_dataframe else dict - type_name = "DataFrame" if is_dataframe else "dict" # Check required columns missing = set(self.config.required_columns) - get_keys(data) @@ -96,6 +142,15 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. f"Custom generator function failed for column '{self.config.name}': {e}" ) from e + return self._postprocess_result(result, is_dataframe, keys_before) + + def _postprocess_result( + self, + result: dict | pd.DataFrame | list[dict], + is_dataframe: bool, + keys_before: set[str], + ) -> dict | pd.DataFrame | list[dict]: + """Validate type and output columns of a generation result.""" # Cell-by-cell with allow_resize: accept dict or list[dict] if not is_dataframe and self.config.allow_resize: if isinstance(result, dict): @@ -113,6 +168,8 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. ) # Validate return type for non-resize paths + expected_type = lazy.pd.DataFrame if is_dataframe else dict + type_name = "DataFrame" if is_dataframe else "dict" if not isinstance(result, expected_type): raise CustomColumnGenerationError( f"Custom generator for column '{self.config.name}' must return a {type_name}, " diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py index 83b13ffd9..88ac05f81 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py @@ -33,3 +33,11 @@ def generate(self, data: dict) -> dict: embeddings = self.model.generate_text_embeddings(input_texts=input_texts) data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json") return data + + async def agenerate(self, data: dict) -> dict: + """Native async generate using model.agenerate_text_embeddings.""" + deserialized_record = deserialize_json_values(data) + input_texts = parse_list_string(deserialized_record[self.config.target_column]) + embeddings = await self.model.agenerate_text_embeddings(input_texts=input_texts) + data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json") + return data diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 31095c490..34fadb627 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING from data_designer.config.column_configs import ImageColumnConfig @@ -74,3 +75,35 @@ def generate(self, data: dict) -> dict: data[self.config.name] = results return data + + async def agenerate(self, data: dict) -> dict: + """Native async generate using model.agenerate_image.""" + deserialized_record = deserialize_json_values(data) + + missing_columns = list(set(self.config.required_columns) - set(data.keys())) + if len(missing_columns) > 0: + raise ValueError( + f"There was an error preparing the Jinja2 expression template. " + f"The following columns {missing_columns} are missing!" + ) + + self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys())) + prompt = self.render_template(deserialized_record) + + if not prompt or not prompt.strip(): + raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty") + + multi_modal_context = self._build_multi_modal_context(deserialized_record) + + base64_images = await self.model.agenerate_image(prompt=prompt, multi_modal_context=multi_modal_context) + + # media_storage.save_base64_image is sync I/O — wrap in thread + results = await asyncio.to_thread( + lambda: [ + self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name) + for base64_image in base64_images + ] + ) + data[self.config.name] = results + + return data diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py index 51ab41d8f..33fb020ba 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -29,6 +29,10 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.FULL_COLUMN + @property + def is_stateful(self) -> bool: + return True + @property def num_records_sampled(self) -> int: return self._num_records_sampled diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py new file mode 100644 index 000000000..4d0574daa --- /dev/null +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py @@ -0,0 +1,364 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +import data_designer.lazy_heavy_imports as lazy +from data_designer.config.column_configs import ( + CustomColumnConfig, + ExpressionColumnConfig, + GenerationStrategy, +) +from data_designer.config.custom_column import custom_column_generator +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + ColumnGeneratorFullColumn, + FromScratchColumnGenerator, + _run_coroutine_sync, +) +from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator +from data_designer.engine.column_generators.generators.llm_completion import ( + ColumnGeneratorWithModelChatCompletion, +) +from data_designer.engine.column_generators.generators.seed_dataset import SeedDatasetColumnGenerator +from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError +from data_designer.engine.resources.resource_provider import ResourceProvider + +# -- Helpers ----------------------------------------------------------------- + + +def _mock_provider() -> Mock: + return Mock(spec=ResourceProvider) + + +def _make_expr_config(name: str = "test") -> ExpressionColumnConfig: + return ExpressionColumnConfig(name=name, expr="{{ col1 }}", dtype="str") + + +# -- _run_coroutine_sync tests ----------------------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_coroutine_sync_with_running_loop() -> None: + """When called inside a running event loop, runs coroutine in a new thread.""" + + async def add(a: int, b: int) -> int: + return a + b + + result = _run_coroutine_sync(add(1, 2)) + assert result == 3 + + +def test_run_coroutine_sync_from_sync_context() -> None: + """When called from sync context (no loop), uses asyncio.run.""" + + async def double(x: int) -> int: + return x * 2 + + result = _run_coroutine_sync(double(5)) + assert result == 10 + + +# -- is_stateful default ---------------------------------------------------- + + +def test_is_stateful_default_false() -> None: + class SyncGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + return data + + gen = SyncGen(config=_make_expr_config(), resource_provider=_mock_provider()) + assert gen.is_stateful is False + + +# -- Symmetric bridging: sync-only generator called via agenerate ----------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sync_only_generator_agenerate() -> None: + """Sync-only generator can be called via agenerate().""" + + class SyncOnlyGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data["result"] = "sync" + return data + + gen = SyncOnlyGen(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate({"col1": "x"}) + assert result["result"] == "sync" + + +# -- Symmetric bridging: async-only generator called via generate ----------- + + +def test_async_only_generator_generate() -> None: + """Async-only generator can be called via generate() from sync context.""" + + class AsyncOnlyGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + async def agenerate(self, data: dict) -> dict: + data["result"] = "async" + return data + + gen = AsyncOnlyGen(config=_make_expr_config(), resource_provider=_mock_provider()) + result = gen.generate({"col1": "x"}) + assert result["result"] == "async" + + +# -- Neither overridden raises NotImplementedError -------------------------- + + +def test_neither_generate_nor_agenerate_raises() -> None: + """If neither generate() nor agenerate() is overridden, generate() raises.""" + + class BareGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + gen = BareGen(config=_make_expr_config(), resource_provider=_mock_provider()) + with pytest.raises(NotImplementedError, match="must implement either"): + gen.generate({"col1": "x"}) + + +# -- FromScratchColumnGenerator async wrappers -------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_from_scratch_agenerate_from_scratch() -> None: + """FromScratchColumnGenerator.agenerate_from_scratch wraps sync correctly.""" + + class TestFromScratch(FromScratchColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame({"val": list(range(num_records))}) + + gen = TestFromScratch(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate_from_scratch(3) + assert len(result) == 3 + assert list(result["val"]) == [0, 1, 2] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_from_scratch_agenerate_passes_copy() -> None: + """FromScratchColumnGenerator.agenerate passes df.copy() to thread.""" + original = lazy.pd.DataFrame({"col1": [1, 2, 3]}) + received_data: list[lazy.pd.DataFrame] = [] + + class TestFromScratch(FromScratchColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + received_data.append(data) + data["new_col"] = "added" + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame() + + gen = TestFromScratch(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate(original) + + # Original should not be mutated + assert "new_col" not in original.columns + assert "new_col" in result.columns + + +# -- ColumnGeneratorFullColumn async wrapper ---------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_full_column_agenerate_passes_copy() -> None: + """ColumnGeneratorFullColumn.agenerate passes df.copy() to thread.""" + original = lazy.pd.DataFrame({"col1": ["a", "b"]}) + + class TestFullCol(ColumnGeneratorFullColumn[ExpressionColumnConfig]): + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + data["added"] = True + return data + + gen = TestFullCol(config=_make_expr_config(), resource_provider=_mock_provider()) + result = await gen.agenerate(original) + + assert "added" not in original.columns + assert "added" in result.columns + + +# -- SeedDatasetColumnGenerator is_stateful ----------------------------------- + + +def test_seed_dataset_is_stateful() -> None: + assert SeedDatasetColumnGenerator.is_stateful.fget is not None # property exists + # Can't instantiate without full setup, so check the class-level property + assert SeedDatasetColumnGenerator.is_stateful.fget(Mock()) is True + + +# -- CustomColumnGenerator agenerate branching -------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_sync_function() -> None: + """Sync custom function is wrapped in asyncio.to_thread via agenerate.""" + + @custom_column_generator() + def sync_fn(row: dict) -> dict: + row["sync_col"] = "hello" + return row + + config = CustomColumnConfig(name="sync_col", generator_function=sync_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + result = await gen.agenerate({"input": "val"}) + assert result["sync_col"] == "hello" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_function() -> None: + """Async custom function is called directly as coroutine.""" + + @custom_column_generator() + async def async_fn(row: dict) -> dict: + row["async_col"] = "async_hello" + return row + + config = CustomColumnConfig(name="async_col", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + result = await gen.agenerate({"input": "val"}) + assert result["async_col"] == "async_hello" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_full_column_wraps_in_thread() -> None: + """Full-column custom generator wraps in asyncio.to_thread with df.copy().""" + + @custom_column_generator() + def full_col_fn(df: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + df["fc_col"] = "batch" + return df + + config = CustomColumnConfig( + name="fc_col", + generator_function=full_col_fn, + generation_strategy=GenerationStrategy.FULL_COLUMN, + ) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + + original = lazy.pd.DataFrame({"input": [1, 2]}) + result = await gen.agenerate(original) + + # Should not mutate the original since we pass .copy() in agenerate + assert "fc_col" not in original.columns + assert "fc_col" in result.columns + + +# -- Existing generators still work unchanged ---------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_llm_completion_agenerate_still_works() -> None: + """Verify LLM completion generators still have working agenerate (from PR #280).""" + assert hasattr(ColumnGeneratorWithModelChatCompletion, "agenerate") + # The agenerate is a custom implementation, not the base default + assert ColumnGeneratorWithModelChatCompletion.agenerate is not ColumnGenerator.agenerate + + +# -- Async custom generator error path parity --------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_missing_required_columns() -> None: + """Async custom generator raises on missing required_columns.""" + + @custom_column_generator(required_columns=["input"]) + async def async_fn(row: dict) -> dict: + row["result"] = row["input"].upper() + return row + + config = CustomColumnConfig(name="result", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="Missing required columns"): + await gen.agenerate({"other": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_missing_output_column() -> None: + """Async custom generator raises when expected output column is missing.""" + + @custom_column_generator() + async def async_fn(row: dict) -> dict: + row["wrong"] = "value" + return row + + config = CustomColumnConfig(name="expected", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="did not create the expected column"): + await gen.agenerate({"input": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_missing_side_effect_column() -> None: + """Async custom generator raises when declared side_effect column is missing.""" + + @custom_column_generator(side_effect_columns=["secondary"]) + async def async_fn(row: dict) -> dict: + row["primary"] = 1 + return row + + config = CustomColumnConfig(name="primary", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="did not create declared side_effect_columns"): + await gen.agenerate({"input": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_allow_resize_invalid_list() -> None: + """Async custom generator with allow_resize rejects invalid non-dict list items.""" + + @custom_column_generator(required_columns=["x"]) + async def async_fn(row: dict) -> list: + return [1, 2] + + config = CustomColumnConfig( + name="out", + generator_function=async_fn, + allow_resize=True, + ) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="list elements must be dicts"): + await gen.agenerate({"x": 1}) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_custom_agenerate_async_wraps_exception() -> None: + """Async custom generator wraps user exceptions in CustomColumnGenerationError.""" + + @custom_column_generator() + async def async_fn(row: dict) -> dict: + raise ValueError("async boom") + + config = CustomColumnConfig(name="result", generator_function=async_fn) + gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) + with pytest.raises(CustomColumnGenerationError, match="Custom generator function failed"): + await gen.agenerate({"input": 1}) diff --git a/plans/346/async-generators-and-task-queue.md b/plans/346/async-generators-and-task-queue.md index 0ac4e7ba7..73ec28a80 100644 --- a/plans/346/async-generators-and-task-queue.md +++ b/plans/346/async-generators-and-task-queue.md @@ -228,17 +228,25 @@ graph. - [x] `ExecutionGraph` class: - Backing stores: `dict[str, set[str]]` column → upstream columns; `dict[str, GenerationStrategy]` column → generation strategy - - `upstream(column: str) -> set[str]` — direct dependencies of a column - - `downstream(column: str) -> set[str]` — columns that depend on this one (for error attribution) - - `strategy(column: str) -> GenerationStrategy` — cell-by-cell or full-column - - `topological_order() -> list[str]` — valid DAG execution order (used by scheduler and for validation) - - `critical_path() -> list[str]` — longest dependency chain (useful for ETA estimates) - - `task_count(num_records: int, buffer_size: int) -> dict[str, int]` — exact task count per + - `get_upstream_columns(column: str) -> set[str]` — direct dependencies of a column + - `get_downstream_columns(column: str) -> set[str]` — columns that depend on this one (for error attribution) + - `get_strategy(column: str) -> GenerationStrategy` — cell-by-cell or full-column + - `get_topological_order() -> list[str]` — valid DAG execution order (cached; used by scheduler and for validation) + - `get_longest_dependency_chain() -> list[str]` — longest dependency chain by column count (useful for ETA estimates) + - `get_root_columns() -> list[str]` — columns with no upstream deps, in topological order + - `split_upstream_by_strategy(column: str) -> tuple[list[str], list[str]]` — splits + upstream into (batch/full-column, cell-by-cell) groups; cached per column + - `compute_task_count(num_records: int, buffer_size: int) -> dict[str, int]` — exact task count per column before the run starts; cell-by-cell columns produce `num_records` tasks, full-column columns (including from-scratch generators, which report `FULL_COLUMN`) produce `ceil(num_records / buffer_size)` tasks + - `compute_cell_dependencies(column, row_group, row_index | None, row_group_size) -> list[SliceRef]` + — derives cell-level deps on demand from column-level DAG + strategy - `to_mermaid() -> str` — Mermaid diagram string; nodes are annotated with strategy type -- [x] `build_execution_graph(column_configs, strategies: dict[str, GenerationStrategy]) -> ExecutionGraph` utility: + - `columns` property — all column names in insertion order + - `add_column(name, strategy)` / `add_edge(upstream, downstream)` — low-level construction + - `set_side_effect(side_effect_col, producer)` / `resolve_side_effect(column) -> str` — side-effect mapping +- [x] `ExecutionGraph.create(column_configs, strategies)` classmethod factory: - Input: the ordered list of `ColumnConfigT` / `MultiColumnConfig`, plus a pre-computed strategy map (available from generators at builder init time via `get_generation_strategy()`) - For each config, read `config.required_columns` → set of upstream column names @@ -246,26 +254,42 @@ graph. and map them back to their producer column, so downstream references resolve correctly - For `MultiColumnConfig`, all sub-columns share the same dependencies - Validate: every required column must resolve to a known producer (including - registered side-effect outputs), and the graph must be acyclic -- [x] Unit tests for graph construction, validation, critical path, task count, and Mermaid output + registered side-effect outputs), and the graph must be acyclic (raises `DAGCircularDependencyError`) +- [x] Unit tests for graph construction, validation, longest chain, task count, cell deps, and Mermaid output **Files**: new module `engine/dataset_builders/utils/execution_graph.py`, tests ### Step 2: Completion Tracker -A lightweight structure tracking which (column, row_group, row_index) tuples are -done. Row indices are **local** to their row group (0-based within each group), -matching the buffer manager's per-row-group addressing. +A frontier-based tracker tracking which (column, row_group, row_index) tuples are +done and maintaining a ready-to-dispatch frontier. Row indices are **local** to their +row group (0-based within each group), matching the buffer manager's per-row-group addressing. - [x] `CompletionTracker` class: - Internal: `dict[int, dict[str, set[int]]]` mapping row_group → column → set of completed local row indices - - `mark_complete(column: str, row_group: int, row_index: int)` / `mark_batch_complete(column: str, row_group: int, row_group_size: int)` - - `is_ready(column: str, row_group: int, row_index: int, graph: ExecutionGraph) -> bool` — checks all upstream columns for that (row_group, row_index) - - `is_batch_ready(column: str, row_group: int, row_group_size: int, graph: ExecutionGraph) -> bool` — checks all rows in group - - `drop_row(row_group: int, row_index: int)` — marks row as dropped across all columns; - `get_ready_tasks` skips dropped rows, in-flight tasks for dropped rows are ignored on completion - - `is_row_group_complete(row_group: int, row_group_size: int, all_columns: list[str]) -> bool` — all non-dropped rows have all columns done; `row_group_size` is the original size, dropped rows (via `drop_row`) are excluded internally - - `get_ready_tasks(graph: ExecutionGraph, row_groups, dispatched: set[Task]) -> list[Task]` — yields all currently dispatchable tasks, excluding dropped rows and already-dispatched/in-flight tasks; reads `graph.strategy(column)` to determine task granularity per column + - `with_graph(graph: ExecutionGraph, row_groups: list[tuple[int, int]]) -> CompletionTracker` — + classmethod factory that creates a frontier-enabled tracker; seeds the frontier with root tasks + - `mark_cell_complete(column, row_group, row_index)` — marks a cell done, discards it + from the frontier, and calls `_enqueue_downstream` to add newly-ready tasks + - `mark_row_range_complete(column, row_group, row_group_size)` — marks an entire batch done, + validates row-group size consistency, and enqueues downstream + - `is_complete(ref: SliceRef) -> bool` — check if a single cell is complete + - `is_all_complete(cells: list[SliceRef]) -> bool` — check if all given cells/batches are complete + - `drop_row(row_group, row_index)` — marks row as dropped; removes cell tasks for that row + from the frontier; calls `_reevaluate_batch_tasks` since dropping a row may unblock + full-column downstream tasks + - `is_dropped(row_group, row_index) -> bool` + - `is_row_group_complete(row_group, row_group_size, all_columns) -> bool` — all non-dropped rows have all columns done + - `get_ready_tasks(dispatched: set[Task]) -> list[Task]` — returns all currently dispatchable + tasks from the frontier, excluding already-dispatched/in-flight tasks; O(frontier) not O(C × R) + - Internal frontier management: + - `_seed_frontier()` — populates frontier with root column tasks (from `graph.get_root_columns()`) + - `_enqueue_downstream(column, row_group, row_index | None)` — on completion, checks each + downstream column's readiness using `split_upstream_by_strategy`; adds ready tasks to frontier + - `_reevaluate_batch_tasks(row_group)` — after row drop, checks if any full-column tasks + became ready (all non-dropped rows now complete) + - Strategy validation: `mark_cell_complete` requires `CELL_BY_CELL`, `mark_row_range_complete` + requires `FULL_COLUMN`; mismatches raise `ValueError` - [x] No locks needed: all access is from the single asyncio event loop thread - [x] Unit tests @@ -273,61 +297,79 @@ matching the buffer manager's per-row-group addressing. ### Step 3: Task Model -Simple dataclass representing a unit of work. +Simple dataclasses representing units of work and cell-level references. -- [x] `Task` dataclass: +- [x] `SliceRef` dataclass (frozen, ordered): + - `column: str`, `row_group: int`, `row_index: int | None = None` + - Reference to a cell or full row group in the execution grid + - Used by `ExecutionGraph.compute_cell_dependencies()` and `CompletionTracker.is_complete()` +- [x] `Task` dataclass (frozen): - `column: str` - `row_group: int` - `row_index: int | None` (None for batch tasks) - `task_type: Literal["from_scratch", "cell", "batch", "pre_batch_processor", "post_batch_processor"]` -- [x] `TaskResult` with status, output, error info +- [x] `TaskResult` dataclass: + - `task: Task`, `status: Literal["success", "error"]`, `output: Any`, `error: Exception | None` + - `retryable: bool = False` — whether the failure can be retried by the salvage loop - [x] `TaskTrace` dataclass (only instantiated when tracing is enabled): - `column: str`, `row_group: int`, `row_index: int | None`, `task_type: str` - `dispatched_at: float` — `perf_counter()` when `create_task()` fires - `slot_acquired_at: float` — after execution semaphore acquired - `completed_at: float` — in `finally` block after generator returns - `status: str`, `error: str | None` + - `from_task(task: Task) -> TaskTrace` classmethod factory - [x] Hashable so we can track dispatched/pending sets +- [x] `DAGCircularDependencyError` in `errors.py` — raised by `ExecutionGraph.get_topological_order()` -**Files**: new module `engine/dataset_builders/utils/task_model.py` — must be its own module -since `CompletionTracker`, `AsyncTaskScheduler`, and the buffer manager all reference `Task`/`TaskResult`; -inlining would create import cycles. +**Files**: new module `engine/dataset_builders/utils/task_model.py` (+ `errors.py`) — must be its own +module since `CompletionTracker`, `AsyncTaskScheduler`, and the buffer manager all reference +`Task`/`TaskResult`/`SliceRef`; inlining would create import cycles. ### Step 4: Async Task Scheduler The core orchestrator that replaces `_run_batch` for the async path. - [ ] `AsyncTaskScheduler` class: - - Constructor takes: generators (by column name), `graph: ExecutionGraph`, completion tracker, row group definitions, concurrency limit (`async_scheduler_max_submitted_tasks`), row group semaphore (`async_max_concurrent_row_groups`), salvage config, error/result callbacks, `trace: bool = False` - - When `trace=True`, populates `scheduler.traces: list[TaskTrace]` (one record per task); otherwise no `TaskTrace` objects are created. See Profiling. + - Constructor takes: generators (by column name), `graph: ExecutionGraph`, row group + definitions (`list[tuple[int, int]]`), concurrency limit (`async_scheduler_max_submitted_tasks`), + row group semaphore (`async_max_concurrent_row_groups`), salvage config, error/result + callbacks, `trace: bool = False` + - Initializes `CompletionTracker.with_graph(graph, row_groups)` — the tracker seeds + its frontier with root tasks automatically + - When `trace=True`, populates `scheduler.traces: list[TaskTrace]` (one record per task, + created via `TaskTrace.from_task()`); otherwise no `TaskTrace` objects are created. See Profiling. - `async run()` — main loop: 1. Acquire the row group semaphore (`async_max_concurrent_row_groups`) before admitting a new row group's seed tasks. Dispatch `from_scratch` tasks, respecting `is_stateful`: stateful generators serialize per-instance (row group N's seed completes before N+1's seed starts for that generator); stateless generators dispatch all admitted row groups concurrently - 2. Loop: query `completion_tracker.get_ready_tasks()` → dispatch each via - `asyncio.create_task()` behind submission budget → on completion, update - tracker → repeat until all tasks done or early shutdown + 2. Loop: pull from `tracker.get_ready_tasks(dispatched)` → dispatch each via + `asyncio.create_task()` behind submission budget → on completion, call + `tracker.mark_cell_complete()` or `tracker.mark_row_range_complete()` (the tracker's + internal `_enqueue_downstream` auto-populates the frontier with newly-ready tasks) + → repeat until all tasks done or early shutdown 3. When ready queue drains, run salvage rounds over deferred retryable failures - (up to `async_salvage_max_rounds` rounds) - 4. After each row group completes: run post-batch processors, checkpoint + (up to `async_salvage_max_rounds` rounds); check `TaskResult.retryable` to classify + 4. After each row group completes (check via `tracker.is_row_group_complete()`): + run post-batch processors, checkpoint - Task dispatch follows the pattern from §4: acquire execution slot → prepare → release → await throttle (LLM only) → reacquire → execute + writeback → release - Admission control: never allow more than `async_scheduler_max_submitted_tasks` - tasks in submitted/running/waiting states; hold remaining ready tasks in the - scheduler queue until slots free up - - Error handling: classify failures as retryable vs non-retryable; retryable - go to deferred queue with backoff; same early-shutdown logic as - `AsyncConcurrentExecutor` (error rate threshold within sliding window) + tasks in submitted/running/waiting states; remove tasks from `dispatched` set on + completion; hold remaining ready tasks in the scheduler queue until slots free up + - Error handling: classify failures as retryable vs non-retryable (set `TaskResult.retryable`); + retryable go to deferred queue with backoff; non-retryable trigger `tracker.drop_row()` + which auto-removes cell tasks from frontier and re-evaluates batch readiness; + same early-shutdown logic as `AsyncConcurrentExecutor` (error rate threshold within sliding window) - Progress tracking: create one `ProgressTracker` per column for accounting (success/failure counts, rate, ETA), but suppress per-completion interval logs in async mode. A separate background coroutine (`asyncio.create_task`) emits a single consolidated summary line every 10 seconds across all active columns; it is cancelled once all tasks complete. See UX Considerations. - [ ] Use `asyncio.Event` to wake the scheduler when a task completes (avoids polling). - `Event` is sufficient — the scheduler resets it and re-checks ready tasks on each wake; - `Condition` would be needed only if waiting on a specific predicate, which the tracker + `Event` is sufficient — the scheduler resets it and re-checks `get_ready_tasks` on each wake; + `Condition` would be needed only if waiting on a specific predicate, which the frontier already handles. - [ ] Unit tests with mock generators @@ -352,38 +394,40 @@ This means sync-first generators (most built-ins, existing plugins) work unchang and async-first generators (new plugins doing native async I/O) only need to implement `agenerate()` without writing a redundant sync version. -- [ ] Add symmetric bridging on the base `ColumnGenerator`: +- [x] Add symmetric bridging on the base `ColumnGenerator`: - `agenerate()` default: `asyncio.to_thread(self.generate, data)` (already exists) - `generate()` default: call a safe sync runner helper that: - uses `asyncio.run()` if no loop is running in the current thread - otherwise submits to the background loop with `run_coroutine_threadsafe(...).result(timeout=...)` - Detect which one the subclass overrides to avoid infinite recursion -- [ ] Add `is_stateful` property to base `ColumnGenerator` (default `False`). + - **Note**: v1 uses ThreadPoolExecutor fallback instead of builder's background loop (available in PR 4) +- [x] Add `is_stateful` property to base `ColumnGenerator` (default `False`). Stateful generators are serialized per-instance by the scheduler. -- [ ] `ColumnGeneratorWithModelChatCompletion.agenerate` — already implemented (PR #280), no changes needed -- [ ] `FromScratchColumnGenerator`: add both async wrappers — `async agenerate_from_scratch(num_records) -> DataFrame` +- [x] `ColumnGeneratorWithModelChatCompletion.agenerate` — already implemented (PR #280), no changes needed +- [x] `FromScratchColumnGenerator`: add both async wrappers — `async agenerate_from_scratch(num_records) -> DataFrame` (wraps `generate_from_scratch` in `asyncio.to_thread`) and `async agenerate(data: DataFrame) -> DataFrame` (wraps `generate` in `asyncio.to_thread` with defensive `df.copy()`). Both are needed because the scheduler dispatches subclasses via either path depending on whether the buffer is empty. -- [ ] `ColumnGeneratorFullColumn`: add `async agenerate(data: DataFrame) -> DataFrame` — wraps sync in +- [x] `ColumnGeneratorFullColumn`: add `async agenerate(data: DataFrame) -> DataFrame` — wraps sync in `asyncio.to_thread` with defensive `df.copy()` (see Risks). This intentionally overrides the base `ColumnGenerator.agenerate(dict)` with a DataFrame-typed signature; the scheduler dispatches the correct variant based on generation strategy. -- [ ] `ExpressionColumnGenerator`: inherits full-column async wrapper -- [ ] `SamplerColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. `is_stateful = False` -- [ ] `SeedDatasetColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. `is_stateful = True` (maintains DuckDB batch reader cursor and leftover-row buffer) -- [ ] `ValidationColumnGenerator`: inherits full-column async wrapper. Note: for `REMOTE` validators +- [x] `ExpressionColumnGenerator`: inherits full-column async wrapper +- [x] `SamplerColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. `is_stateful = False` +- [x] `SeedDatasetColumnGenerator`: inherits both wrappers from `FromScratchColumnGenerator`; no custom implementation needed. `is_stateful = True` (maintains DuckDB batch reader cursor and leftover-row buffer) +- [x] `ValidationColumnGenerator`: inherits full-column async wrapper. Note: for `REMOTE` validators with `max_parallel_requests > 1`, `generate()` internally uses `ConcurrentThreadExecutor`, so the async wrapper spawns a thread that itself spawns more threads — bypassing the scheduler's concurrency controls for those HTTP calls. Acceptable for v1 (see Follow-ups). -- [ ] `CustomColumnGenerator`: inherits directly from `ColumnGenerator` (not from +- [x] `CustomColumnGenerator`: inherits directly from `ColumnGenerator` (not from `ColumnGeneratorFullColumn`), so it does not automatically inherit the full-column async wrapper. Needs its own `agenerate` that branches on strategy: - `CELL_BY_CELL`: if the user function is a coroutine (`asyncio.iscoroutinefunction`), call it directly; otherwise wrap in `asyncio.to_thread` - `FULL_COLUMN`: wrap `generate(DataFrame)` in `asyncio.to_thread` with defensive `df.copy()` `is_stateful` defaults to `False`; custom implementations can override it. -- [ ] `ImageCellGenerator`, `EmbeddingCellGenerator`: add native `agenerate` using `model.agenerate_image` / `model.agenerate_text_embeddings` + - **Note**: uses `inspect.unwrap()` to detect async through the `@custom_column_generator` decorator wrapper +- [x] `ImageCellGenerator`, `EmbeddingCellGenerator`: add native `agenerate` using `model.agenerate_image` / `model.agenerate_text_embeddings` **Files**: `generators/base.py`, `generators/expression.py`, `generators/samplers.py`, `generators/seed_dataset.py`, `generators/image.py`, `generators/embedding.py`, tests @@ -410,9 +454,12 @@ Adapt `DatasetBatchManager` for concurrent row group processing. Wire the new scheduler into `ColumnWiseDatasetBuilder`. - [ ] New method `_build_async(generators, num_records, buffer_size, ...)`: - 1. Build `ExecutionGraph` from `self._column_configs` and generator strategies - 2. Partition rows into row groups - 3. Create `CompletionTracker`, `AsyncTaskScheduler` + 1. Build `ExecutionGraph.create(self._column_configs, strategies)` from configs and + generator strategies; catch `DAGCircularDependencyError` and `ValueError` and + re-raise as `DatasetGenerationError` with context + 2. Partition rows into row groups as `list[tuple[int, int]]` (rg_id, rg_size) + 3. Create `AsyncTaskScheduler` (which internally creates + `CompletionTracker.with_graph(graph, row_groups)`) 4. Run scheduler on the background event loop (reuse `_ensure_async_engine_loop()` from `dataset_builders/utils/async_concurrency.py` — already exists) 5. Scheduler handles checkpointing via callbacks @@ -431,33 +478,35 @@ Wire the new scheduler into `ColumnWiseDatasetBuilder`. Tests are added incrementally with each PR, not deferred to the end. -**PR 1 (foundation) — unit tests**: -- [x] Execution graph construction, validation, topological order, critical path +**PR 1 (foundation) — unit tests** (merged): +- [x] Execution graph construction, validation, `get_topological_order`, `get_longest_dependency_chain` - [x] Execution graph: side-effect output columns resolve correctly (e.g., column depending on `summary__trace` maps to a dependency on the `summary` generator) -- [x] Execution graph: `cell_dependencies` returns correct deps for cell-by-cell, +- [x] Execution graph: `compute_cell_dependencies` returns correct deps for cell-by-cell, full-column, and from-scratch columns -- [x] Execution graph: `task_count` and `to_mermaid` output -- [x] Completion tracker: `mark_complete`, `is_complete`, `all_complete` -- [x] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete` -- [x] Task model: hashability, equality, TaskResult, TaskTrace +- [x] Execution graph: `compute_task_count`, `split_upstream_by_strategy`, and `to_mermaid` output +- [x] Completion tracker: `mark_cell_complete`, `mark_row_range_complete`, `is_complete`, `is_all_complete` +- [x] Completion tracker: frontier-based `get_ready_tasks` with `with_graph` initialization +- [x] Completion tracker: `drop_row`, `is_dropped`, `is_row_group_complete`, `_reevaluate_batch_tasks` +- [x] Task model: hashability, equality, TaskResult (including `retryable`), TaskTrace, SliceRef **PR 2 (generators) — unit tests**: -- [ ] Symmetric bridging: sync-only generator can be called via `agenerate` -- [ ] Symmetric bridging: async-only generator can be called via `generate` -- [ ] `is_stateful` defaults to `False`; `SeedDatasetColumnGenerator` returns `True` -- [ ] `FromScratchColumnGenerator.agenerate_from_scratch` wraps sync correctly -- [ ] `ColumnGeneratorFullColumn.agenerate` passes `df.copy()` to thread -- [ ] `CustomColumnGenerator.agenerate` detects coroutine functions and calls directly -- [ ] All existing generator tests pass unchanged (`make test`) +- [x] Symmetric bridging: sync-only generator can be called via `agenerate` +- [x] Symmetric bridging: async-only generator can be called via `generate` +- [x] `is_stateful` defaults to `False`; `SeedDatasetColumnGenerator` returns `True` +- [x] `FromScratchColumnGenerator.agenerate_from_scratch` wraps sync correctly +- [x] `ColumnGeneratorFullColumn.agenerate` passes `df.copy()` to thread +- [x] `CustomColumnGenerator.agenerate` detects coroutine functions and calls directly +- [x] All existing generator tests pass unchanged (`make test`) **PR 3 (scheduler + buffer) — unit tests with mock generators**: -- [ ] Scheduler dispatches from-scratch tasks first, then downstream as deps complete +- [ ] Scheduler dispatches root tasks first (from `tracker.get_ready_tasks`), + then downstream as deps complete (via tracker's `_enqueue_downstream`) - [ ] Stateful generator serializes across row groups; stateless runs concurrently -- [ ] Retry salvage: transient failure is retried and succeeds; - non-retryable failure drops immediately; retry budget exhaustion drops correctly -- [ ] Eager row-drop: failure on column B drops the row across all columns, - independent column C does not process the dropped row +- [ ] Retry salvage: transient failure (`TaskResult.retryable=True`) is retried and succeeds; + non-retryable failure triggers `tracker.drop_row()` immediately; retry budget exhaustion drops correctly +- [ ] Eager row-drop: failure on column B calls `tracker.drop_row()` which removes + cell tasks for that row from frontier; independent column C does not process the dropped row - [ ] Row-drop with in-flight full-column task: completed task may still compute dropped rows, but writeback is suppressed and row remains dropped - [ ] Bounded submission: submitted task count never exceeds @@ -489,9 +538,10 @@ The implementation steps map to 4 PRs that can be reviewed and merged independen Each PR is self-contained: it adds new modules with full test coverage but does not change existing behavior until the final integration PR. -### PR 1: Foundation (Steps 1 + 2 + 3) +### PR 1: Foundation (Steps 1 + 2 + 3) — MERGED as [#356](https://github.com/NVIDIA-NeMo/DataDesigner/pull/356) -**Scope**: `ExecutionGraph`, `CompletionTracker`, `Task`/`TaskResult`/`TaskTrace` dataclasses. +**Scope**: `ExecutionGraph`, `CompletionTracker`, `SliceRef`/`Task`/`TaskResult`/`TaskTrace` +dataclasses, `DAGCircularDependencyError`. All three are pure data structures with no side effects on the existing codebase. They live in new modules under `engine/dataset_builders/utils/` and are only imported @@ -500,16 +550,18 @@ by code introduced in later PRs. - `execution_graph.py` + tests - `completion_tracker.py` + tests - `task_model.py` + tests +- `errors.py` (`DAGCircularDependencyError`) **Why grouped**: the three are tightly coupled (the tracker takes the graph to resolve readiness, the task model is the unit of work for both), small individually, and have no external dependencies. Splitting them into 3 separate PRs would create review overhead without meaningful isolation benefit. -**What works after merge**: you can build an `ExecutionGraph` from any existing config, -inspect it (`topological_order`, `critical_path`, `task_count`, `to_mermaid`), query -cell-level dependencies, and track completion state — all in isolation, with full test -coverage. No runtime behavior changes. +**What works after merge**: you can build an `ExecutionGraph.create()` from any existing config, +inspect it (`get_topological_order`, `get_longest_dependency_chain`, `compute_task_count`, +`to_mermaid`), query cell-level dependencies via `compute_cell_dependencies()`, and track +completion state with the frontier-enabled `CompletionTracker.with_graph()` — all in +isolation, with full test coverage. No runtime behavior changes. **Can merge independently**: yes — no existing code imports these modules. @@ -540,13 +592,21 @@ Existing sync callers are unaffected. **Scope**: `AsyncTaskScheduler`, row group buffer manager. -- `async_scheduler.py` + tests (uses graph, tracker, and task model from PR 1) +- `async_scheduler.py` + tests (uses `ExecutionGraph.create()`, + `CompletionTracker.with_graph()`, `Task`, `TaskResult`, `TaskTrace`, `SliceRef` from PR 1) - Buffer manager extension in `dataset_batch_manager.py` + tests - Retry/salvage logic, progress consolidation, error handling -**Depends on**: PR 1 (imports `ExecutionGraph`, `CompletionTracker`, `Task`), PR 2 +**Depends on**: PR 1 (imports `ExecutionGraph`, `CompletionTracker`, `Task`, `SliceRef`), PR 2 (calls `agenerate` / `agenerate_from_scratch`, reads `is_stateful`). +**Key integration with PR 1's frontier model**: The scheduler initializes +`CompletionTracker.with_graph(graph, row_groups)` which auto-seeds the frontier with +root tasks. The main loop pulls from `tracker.get_ready_tasks(dispatched)`, and on task +completion calls `mark_cell_complete()` / `mark_row_range_complete()` which internally +enqueues newly-ready downstream tasks. On row drop, calls `tracker.drop_row()` which +removes frontier tasks and re-evaluates batch readiness. + **What works after merge**: the scheduler can be instantiated with mock generators and driven through its full lifecycle in tests — row group admission, dependency-driven dispatch, retry/salvage, row drops, checkpoint callbacks. The buffer manager supports @@ -829,11 +889,13 @@ mid-run loses at most one batch. **`ExecutionTraits` replaced by `GenerationStrategy` on the graph.** PR #269 attaches an `ExecutionTraits` flag enum (`CELL`, `BARRIER`, `ROW_STREAMABLE`) to each node. Since our graph is column-level, we store `GenerationStrategy` (cell-by-cell, full-column) directly -on each column node instead. From-scratch columns are identified by having no upstream -dependencies in the graph; the scheduler checks `can_generate_from_scratch` on the generator -instance to determine which method to call. This serves the same purpose as `ExecutionTraits` -— the scheduler and `CompletionTracker` use it to determine task granularity — without -needing typed node IDs or flag combinations. +on each column node instead (accessible via `get_strategy()`). From-scratch columns are +identified by having no upstream dependencies in the graph (via `get_root_columns()`); the +scheduler checks `can_generate_from_scratch` on the generator instance to determine which +method to call. The `split_upstream_by_strategy()` method provides cached separation of +upstream deps by strategy type, used by the tracker's frontier logic. This serves the same +purpose as `ExecutionTraits` — the scheduler and `CompletionTracker` use it to determine +task granularity — without needing typed node IDs or flag combinations. **`ROW_STREAMABLE` trait omitted.** PR #269 introduces `is_row_streamable` so full-column generators that process rows independently (e.g., `ExpressionColumnGenerator`) can be From 1696fb5c8e84117d56f1117bac4a633504f29a5a Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 9 Mar 2026 11:17:17 -0300 Subject: [PATCH 2/8] fix: avoid blocking caller on sync bridge timeout Use explicit pool lifecycle instead of context manager so that a TimeoutError releases the caller immediately via shutdown(wait=False) rather than blocking on pool.__exit__. --- .../engine/column_generators/generators/base.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 5ce859432..6ade0faf8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -38,9 +38,16 @@ def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: asyncio.get_running_loop() except RuntimeError: return asyncio.run(coro) - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(asyncio.run, coro) + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = pool.submit(asyncio.run, coro) + try: return future.result(timeout=_SYNC_BRIDGE_TIMEOUT) + except concurrent.futures.TimeoutError: + pool.shutdown(wait=False, cancel_futures=True) + logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running") + raise + else: + pool.shutdown(wait=True) class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): From 8b6e8b86be9e3339a56eb32679b1e7cf2e6959b7 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 9 Mar 2026 11:36:27 -0300 Subject: [PATCH 3/8] fix: widen agenerate type signature to match generate Add @overload declarations so the base agenerate accepts both dict and pd.DataFrame, mirroring the existing generate pattern. --- .../engine/column_generators/generators/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 6ade0faf8..36cc83fa8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -85,7 +85,13 @@ def generate(self, data: DataT) -> DataT: raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()") return _run_coroutine_sync(self.agenerate(data)) - async def agenerate(self, data: dict) -> dict: + @overload + async def agenerate(self, data: dict) -> dict: ... + + @overload + async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: ... + + async def agenerate(self, data: DataT) -> DataT: """Async generate — delegates to sync ``generate()`` via thread pool. Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion) From afbb60f991b55d87613090bd28b8c11bafa2588f Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 9 Mar 2026 16:37:49 -0300 Subject: [PATCH 4/8] fix: ensure pool shutdown on sync bridge success path The else clause after return was unreachable, leaking the ThreadPoolExecutor on every successful call. Capture the result first, shut down the pool, then return. --- .../engine/column_generators/generators/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 36cc83fa8..635013407 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -41,13 +41,13 @@ def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) future = pool.submit(asyncio.run, coro) try: - return future.result(timeout=_SYNC_BRIDGE_TIMEOUT) + result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT) except concurrent.futures.TimeoutError: pool.shutdown(wait=False, cancel_futures=True) logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running") raise - else: - pool.shutdown(wait=True) + pool.shutdown(wait=True) + return result class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): From b7ebd3fb6ea175c663a94494c14247f32dc3597d Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 9 Mar 2026 16:58:48 -0300 Subject: [PATCH 5/8] fix: use try/finally for pool shutdown in sync bridge Ensures ThreadPoolExecutor is shut down on all exit paths, including non-TimeoutError exceptions from the coroutine. --- .../engine/column_generators/generators/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 635013407..b47f4f932 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -40,13 +40,15 @@ def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: return asyncio.run(coro) pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) future = pool.submit(asyncio.run, coro) + timed_out = False try: result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT) except concurrent.futures.TimeoutError: - pool.shutdown(wait=False, cancel_futures=True) + timed_out = True logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running") raise - pool.shutdown(wait=True) + finally: + pool.shutdown(wait=not timed_out, cancel_futures=timed_out) return result From 3699311d6a9a8332990a7e117a60419c2ef7491d Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Mon, 9 Mar 2026 17:24:08 -0300 Subject: [PATCH 6/8] refactor: extract shared validation in ImageCellGenerator Move duplicated input validation and prompt rendering into _prepare_image_inputs, shared by generate and agenerate. --- .../column_generators/generators/image.py | 52 +++---------------- 1 file changed, 8 insertions(+), 44 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py index 34fadb627..2060afb64 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/image.py @@ -32,72 +32,37 @@ def media_storage(self) -> MediaStorage: def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL - def generate(self, data: dict) -> dict: - """Generate image(s) and optionally save to disk. - - Args: - data: Record data - - Returns: - Record with image path(s) (create mode) or base64 data (preview mode) added - """ + def _prepare_image_inputs(self, data: dict) -> tuple[str, list[dict] | None]: + """Validate inputs and render prompt for image generation.""" deserialized_record = deserialize_json_values(data) - - # Validate required columns missing_columns = list(set(self.config.required_columns) - set(data.keys())) if len(missing_columns) > 0: - error_msg = ( + raise ValueError( f"There was an error preparing the Jinja2 expression template. " f"The following columns {missing_columns} are missing!" ) - raise ValueError(error_msg) - - # Render prompt template self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys())) prompt = self.render_template(deserialized_record) - - # Validate prompt is non-empty if not prompt or not prompt.strip(): raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty") - - # Process multi-modal context if provided multi_modal_context = self._build_multi_modal_context(deserialized_record) + return prompt, multi_modal_context - # Generate images (returns list of base64 strings) + def generate(self, data: dict) -> dict: + """Generate image(s) and optionally save to disk.""" + prompt, multi_modal_context = self._prepare_image_inputs(data) base64_images = self.model.generate_image(prompt=prompt, multi_modal_context=multi_modal_context) - - # Store via media storage (mode determines disk vs dataframe storage) - # Use column name as subfolder to organize images results = [ self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name) for base64_image in base64_images ] data[self.config.name] = results - return data async def agenerate(self, data: dict) -> dict: """Native async generate using model.agenerate_image.""" - deserialized_record = deserialize_json_values(data) - - missing_columns = list(set(self.config.required_columns) - set(data.keys())) - if len(missing_columns) > 0: - raise ValueError( - f"There was an error preparing the Jinja2 expression template. " - f"The following columns {missing_columns} are missing!" - ) - - self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys())) - prompt = self.render_template(deserialized_record) - - if not prompt or not prompt.strip(): - raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty") - - multi_modal_context = self._build_multi_modal_context(deserialized_record) - + prompt, multi_modal_context = self._prepare_image_inputs(data) base64_images = await self.model.agenerate_image(prompt=prompt, multi_modal_context=multi_modal_context) - - # media_storage.save_base64_image is sync I/O — wrap in thread results = await asyncio.to_thread( lambda: [ self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name) @@ -105,5 +70,4 @@ async def agenerate(self, data: dict) -> dict: ] ) data[self.config.name] = results - return data From 011ec36f3d6050aab3b66224e17b8873bf3da3f5 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Tue, 10 Mar 2026 13:04:47 -0300 Subject: [PATCH 7/8] refactor: extract shared input prep in EmbeddingCellGenerator --- .../engine/column_generators/generators/embedding.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py index 88ac05f81..82eaf795b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/embedding.py @@ -27,17 +27,19 @@ class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]): def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL - def generate(self, data: dict) -> dict: + def _prepare_embedding_inputs(self, data: dict) -> list[str]: deserialized_record = deserialize_json_values(data) - input_texts = parse_list_string(deserialized_record[self.config.target_column]) + return parse_list_string(deserialized_record[self.config.target_column]) + + def generate(self, data: dict) -> dict: + input_texts = self._prepare_embedding_inputs(data) embeddings = self.model.generate_text_embeddings(input_texts=input_texts) data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json") return data async def agenerate(self, data: dict) -> dict: """Native async generate using model.agenerate_text_embeddings.""" - deserialized_record = deserialize_json_values(data) - input_texts = parse_list_string(deserialized_record[self.config.target_column]) + input_texts = self._prepare_embedding_inputs(data) embeddings = await self.model.agenerate_text_embeddings(input_texts=input_texts) data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json") return data From 67212a3e626db6c98aede9d2d56d1cddcd8c1179 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 11 Mar 2026 11:50:30 -0300 Subject: [PATCH 8/8] address PR review feedback - add _is_overridden helper for symmetric generate/agenerate guards - move defensive .copy() into base agenerate, remove subclass overrides - re-raise as builtin TimeoutError for Python 3.10 compat - rename is_stateful to is_order_dependent with improved docstring - replace brittle .fget test with object.__new__ - add async tests for ImageCellGenerator and EmbeddingCellGenerator --- .../column_generators/generators/base.py | 30 ++++--- .../generators/seed_dataset.py | 2 +- .../generators/test_async_generators.py | 78 ++++++++++++++++--- 3 files changed, 84 insertions(+), 26 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index b47f4f932..4bb497a9d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -43,10 +43,10 @@ def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: timed_out = False try: result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT) - except concurrent.futures.TimeoutError: + except concurrent.futures.TimeoutError as exc: timed_out = True logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running") - raise + raise TimeoutError(f"_run_coroutine_sync timed out after {_SYNC_BRIDGE_TIMEOUT}s") from exc finally: pool.shutdown(wait=not timed_out, cancel_futures=timed_out) return result @@ -58,14 +58,18 @@ def can_generate_from_scratch(self) -> bool: return False @property - def is_stateful(self) -> bool: - """Whether this generator maintains state across calls. + def is_order_dependent(self) -> bool: + """Whether this generator's output depends on prior row-group calls. - Stateful generators are serialized per-instance by the async scheduler - (row group N must complete before N+1 starts for that generator). + Example: SeedDatasetColumnGenerator tracks its position in the seed + dataset, so row group N must complete before N+1 starts. """ return False + def _is_overridden(self, method_name: str) -> bool: + """Check if a subclass has overridden a base ColumnGenerator method.""" + return getattr(type(self), method_name) is not getattr(ColumnGenerator, method_name) + @staticmethod @abstractmethod def get_generation_strategy() -> GenerationStrategy: ... @@ -83,7 +87,7 @@ def generate(self, data: DataT) -> DataT: implement ``agenerate()``. Raises ``NotImplementedError`` if neither ``generate()`` nor ``agenerate()`` is overridden. """ - if type(self).agenerate is ColumnGenerator.agenerate: + if not self._is_overridden("agenerate"): raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()") return _run_coroutine_sync(self.agenerate(data)) @@ -99,7 +103,9 @@ async def agenerate(self, data: DataT) -> DataT: Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion) should override this with a direct async implementation. """ - return await asyncio.to_thread(self.generate, data) + if not self._is_overridden("generate"): + raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()") + return await asyncio.to_thread(self.generate, data.copy()) def log_pre_generation(self) -> None: """A shared method to log info before the generator's `generate` method is called. @@ -122,10 +128,6 @@ async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame: """Async wrapper — wraps sync ``generate_from_scratch()`` in a thread.""" return await asyncio.to_thread(self.generate_from_scratch, num_records) - async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: - """Async wrapper — wraps sync ``generate()`` in a thread with defensive copy.""" - return await asyncio.to_thread(self.generate, data.copy()) - class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC): @property @@ -213,7 +215,3 @@ def get_generation_strategy() -> GenerationStrategy: @abstractmethod def generate(self, data: pd.DataFrame) -> pd.DataFrame: ... - - async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: - """Async wrapper — wraps sync ``generate()`` in a thread with defensive copy.""" - return await asyncio.to_thread(self.generate, data.copy()) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py index 4f1380736..a310ca3c8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -30,7 +30,7 @@ def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.FULL_COLUMN @property - def is_stateful(self) -> bool: + def is_order_dependent(self) -> bool: return True @property diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py index 4d0574daa..7eff29fec 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_async_generators.py @@ -3,15 +3,17 @@ from __future__ import annotations -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock, patch import pytest import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import ( CustomColumnConfig, + EmbeddingColumnConfig, ExpressionColumnConfig, GenerationStrategy, + ImageColumnConfig, ) from data_designer.config.custom_column import custom_column_generator from data_designer.engine.column_generators.generators.base import ( @@ -21,6 +23,11 @@ _run_coroutine_sync, ) from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator +from data_designer.engine.column_generators.generators.embedding import ( + EmbeddingCellGenerator, + EmbeddingGenerationResult, +) +from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.column_generators.generators.llm_completion import ( ColumnGeneratorWithModelChatCompletion, ) @@ -63,10 +70,10 @@ async def double(x: int) -> int: assert result == 10 -# -- is_stateful default ---------------------------------------------------- +# -- is_order_dependent default ---------------------------------------------------- -def test_is_stateful_default_false() -> None: +def test_is_order_dependent_default_false() -> None: class SyncGen(ColumnGenerator[ExpressionColumnConfig]): @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -76,7 +83,7 @@ def generate(self, data: dict) -> dict: return data gen = SyncGen(config=_make_expr_config(), resource_provider=_mock_provider()) - assert gen.is_stateful is False + assert gen.is_order_dependent is False # -- Symmetric bridging: sync-only generator called via agenerate ----------- @@ -136,6 +143,20 @@ def get_generation_strategy() -> GenerationStrategy: gen.generate({"col1": "x"}) +@pytest.mark.asyncio(loop_scope="session") +async def test_neither_generate_nor_agenerate_raises_from_async() -> None: + """If neither is overridden, agenerate() raises directly without thread bounce.""" + + class BareGen(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + gen = BareGen(config=_make_expr_config(), resource_provider=_mock_provider()) + with pytest.raises(NotImplementedError, match="must implement either"): + await gen.agenerate({"col1": "x"}) + + # -- FromScratchColumnGenerator async wrappers -------------------------------- @@ -207,13 +228,12 @@ def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: assert "added" in result.columns -# -- SeedDatasetColumnGenerator is_stateful ----------------------------------- +# -- SeedDatasetColumnGenerator is_order_dependent ----------------------------------- -def test_seed_dataset_is_stateful() -> None: - assert SeedDatasetColumnGenerator.is_stateful.fget is not None # property exists - # Can't instantiate without full setup, so check the class-level property - assert SeedDatasetColumnGenerator.is_stateful.fget(Mock()) is True +def test_seed_dataset_is_order_dependent() -> None: + gen = object.__new__(SeedDatasetColumnGenerator) + assert gen.is_order_dependent is True # -- CustomColumnGenerator agenerate branching -------------------------------- @@ -362,3 +382,43 @@ async def async_fn(row: dict) -> dict: gen = CustomColumnGenerator(config=config, resource_provider=_mock_provider()) with pytest.raises(CustomColumnGenerationError, match="Custom generator function failed"): await gen.agenerate({"input": 1}) + + +# -- ImageCellGenerator async ------------------------------------------------ + + +@pytest.mark.asyncio(loop_scope="session") +async def test_image_agenerate(stub_resource_provider: Mock) -> None: + """ImageCellGenerator.agenerate calls model.agenerate_image.""" + mock_storage = Mock() + mock_storage.save_base64_image.side_effect = ["images/img1.png", "images/img2.png"] + stub_resource_provider.artifact_storage.media_storage = mock_storage + + config = ImageColumnConfig(name="test_image", prompt="A {{ style }} image", model_alias="test_model") + gen = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + + with patch.object(gen, "model") as mock_model: + mock_model.agenerate_image = AsyncMock(return_value=["b64_1", "b64_2"]) + result = await gen.agenerate({"style": "photorealistic"}) + + assert result["test_image"] == ["images/img1.png", "images/img2.png"] + mock_model.agenerate_image.assert_awaited_once() + + +# -- EmbeddingCellGenerator async -------------------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_embedding_agenerate(stub_resource_provider: Mock) -> None: + """EmbeddingCellGenerator.agenerate calls model.agenerate_text_embeddings.""" + config = EmbeddingColumnConfig(name="test_emb", target_column="text", model_alias="test_model") + gen = EmbeddingCellGenerator(config=config, resource_provider=stub_resource_provider) + + stub_embeddings = [[0.1, 0.2], [0.3, 0.4]] + with patch.object(gen, "model") as mock_model: + mock_model.agenerate_text_embeddings = AsyncMock(return_value=stub_embeddings) + result = await gen.agenerate({"text": "['hello', 'world']"}) + + expected = EmbeddingGenerationResult(embeddings=stub_embeddings).model_dump(mode="json") + assert result["test_emb"] == expected + mock_model.agenerate_text_embeddings.assert_awaited_once_with(input_texts=["hello", "world"])