-
Notifications
You must be signed in to change notification settings - Fork 157
feat: add async generator migration with symmetric bridging and statefulness #378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
99750d6
64c04a8
1696fb5
8b6e8b8
e9852c3
afbb60f
b7ebd3f
3699311
011ec36
67212a3
1928b32
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,15 +4,20 @@ | |
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import concurrent.futures | ||
| import functools | ||
| import logging | ||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING, Any, overload | ||
| from typing import TYPE_CHECKING, Any, Coroutine, TypeVar, overload | ||
|
|
||
| from data_designer.config.column_configs import GenerationStrategy | ||
| from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT | ||
| from data_designer.logging import LOG_DOUBLE_INDENT, LOG_INDENT | ||
|
|
||
| _T = TypeVar("_T") | ||
|
|
||
| _SYNC_BRIDGE_TIMEOUT = 300 | ||
|
|
||
| if TYPE_CHECKING: | ||
| import pandas as pd | ||
|
|
||
|
|
@@ -23,28 +28,73 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T: | ||
| """Run an async coroutine from sync context. | ||
|
|
||
| - No running event loop → ``asyncio.run(coro)`` | ||
| - Running event loop (e.g. notebook/service) → run in a background thread | ||
| """ | ||
| try: | ||
| asyncio.get_running_loop() | ||
| except RuntimeError: | ||
| return asyncio.run(coro) | ||
| pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) | ||
| future = pool.submit(asyncio.run, coro) | ||
| timed_out = False | ||
| try: | ||
| result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT) | ||
| except concurrent.futures.TimeoutError: | ||
| timed_out = True | ||
| logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running") | ||
| raise | ||
| finally: | ||
| pool.shutdown(wait=not timed_out, cancel_futures=timed_out) | ||
| return result | ||
|
andreatgretel marked this conversation as resolved.
|
||
|
|
||
|
|
||
| class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): | ||
| @property | ||
| def can_generate_from_scratch(self) -> bool: | ||
| return False | ||
|
|
||
| @property | ||
| def is_stateful(self) -> bool: | ||
| """Whether this generator maintains state across calls. | ||
|
|
||
| Stateful generators are serialized per-instance by the async scheduler | ||
| (row group N must complete before N+1 starts for that generator). | ||
| """ | ||
| return False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm still trying to fully understand this property. I think we want the "stateful" part in there because this is for columns like the seed column, which needs to remember where it is at in the generation process – is that right? I think the second part of the docstring is a bit hard to follow (might be just me, though).
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @nabinchha renamed to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @johnnygreco yeah exactly - the seed column needs to remember where it is in the dataset. renamed to |
||
|
|
||
| @staticmethod | ||
| @abstractmethod | ||
| def get_generation_strategy() -> GenerationStrategy: ... | ||
|
|
||
| @overload | ||
| @abstractmethod | ||
| def generate(self, data: dict) -> dict: ... | ||
|
|
||
| @overload | ||
| @abstractmethod | ||
| def generate(self, data: pd.DataFrame) -> pd.DataFrame: ... | ||
|
|
||
| @abstractmethod | ||
| def generate(self, data: DataT) -> DataT: ... | ||
| def generate(self, data: DataT) -> DataT: | ||
|
andreatgretel marked this conversation as resolved.
|
||
| """Sync generate — overridden by most concrete generators. | ||
|
|
||
| Default bridges to ``agenerate()`` for async-first subclasses that only | ||
| implement ``agenerate()``. Raises ``NotImplementedError`` if neither | ||
| ``generate()`` nor ``agenerate()`` is overridden. | ||
| """ | ||
| if type(self).agenerate is ColumnGenerator.agenerate: | ||
| raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()") | ||
| return _run_coroutine_sync(self.agenerate(data)) | ||
|
|
||
| @overload | ||
| async def agenerate(self, data: dict) -> dict: ... | ||
|
|
||
| async def agenerate(self, data: dict) -> dict: | ||
| """Async fallback — delegates to sync generate via thread pool. | ||
| @overload | ||
| async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: ... | ||
|
|
||
| async def agenerate(self, data: DataT) -> DataT: | ||
| """Async generate — delegates to sync ``generate()`` via thread pool. | ||
|
|
||
| Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion) | ||
| should override this with a direct async implementation. | ||
|
|
@@ -68,6 +118,14 @@ def can_generate_from_scratch(self) -> bool: | |
| @abstractmethod | ||
| def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ... | ||
|
|
||
| async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame: | ||
| """Async wrapper — wraps sync ``generate_from_scratch()`` in a thread.""" | ||
| return await asyncio.to_thread(self.generate_from_scratch, num_records) | ||
|
|
||
| async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: | ||
| """Async wrapper — wraps sync ``generate()`` in a thread with defensive copy.""" | ||
| return await asyncio.to_thread(self.generate, data.copy()) | ||
|
andreatgretel marked this conversation as resolved.
Outdated
andreatgretel marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC): | ||
| @property | ||
|
|
@@ -155,3 +213,7 @@ def get_generation_strategy() -> GenerationStrategy: | |
|
|
||
| @abstractmethod | ||
| def generate(self, data: pd.DataFrame) -> pd.DataFrame: ... | ||
|
|
||
| async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: | ||
|
andreatgretel marked this conversation as resolved.
Outdated
|
||
| """Async wrapper — wraps sync ``generate()`` in a thread with defensive copy.""" | ||
| return await asyncio.to_thread(self.generate, data.copy()) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import inspect | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
@@ -65,12 +66,57 @@ def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict | |
|
|
||
| return self._generate(data, is_dataframe) | ||
|
|
||
| async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict]: | ||
| """Async generate — branches on strategy and detects coroutine functions.""" | ||
| is_full_column = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN | ||
| if is_full_column: | ||
| return await asyncio.to_thread(self.generate, data.copy()) | ||
| # The @custom_column_generator decorator wraps the user function in a sync | ||
| # wrapper, so we must unwrap to detect async functions. | ||
|
Comment on lines
+74
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. lol meant fancy stuff here 🙃
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yeah the decorator wrapping forces our hand here - |
||
| fn_unwrapped = inspect.unwrap(self.config.generator_function) | ||
| if asyncio.iscoroutinefunction(fn_unwrapped): | ||
| missing = set(self.config.required_columns) - set(data.keys()) | ||
| if missing: | ||
| raise CustomColumnGenerationError( | ||
| f"Missing required columns for custom generator '{self.config.name}': {sorted(missing)}" | ||
| ) | ||
| keys_before = set(data.keys()) | ||
|
|
||
| try: | ||
| result = await self._ainvoke_generator_function(data) | ||
| except CustomColumnGenerationError: | ||
| raise | ||
| except Exception as e: | ||
| logger.warning( | ||
| f"⚠️ Custom generator function {self.config.generator_function.__name__!r} " | ||
| f"failed for column '{self.config.name}'. This record will be skipped.\n{e}" | ||
| ) | ||
| raise CustomColumnGenerationError( | ||
| f"Custom generator function failed for column '{self.config.name}': {e}" | ||
| ) from e | ||
|
|
||
| return self._postprocess_result(result, is_dataframe=False, keys_before=keys_before) | ||
| return await asyncio.to_thread(self.generate, data) | ||
|
andreatgretel marked this conversation as resolved.
andreatgretel marked this conversation as resolved.
|
||
|
|
||
| async def _ainvoke_generator_function(self, data: dict) -> dict | pd.DataFrame: | ||
| """Invoke an async user generator function with appropriate arguments. | ||
|
|
||
| The @custom_column_generator decorator's sync wrapper returns a coroutine | ||
| when the original function is async, so we await the wrapper's return value. | ||
| """ | ||
| params = self._get_validated_params() | ||
| fn = self.config.generator_function | ||
| if len(params) == 1: | ||
| return await fn(data) | ||
| elif len(params) == 2: | ||
| return await fn(data, self.config.generator_params) | ||
| else: | ||
| models = self._build_models_dict() | ||
| return await fn(data, self.config.generator_params, models) | ||
|
|
||
| def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame | list[dict]: | ||
| """Unified generation logic for both strategies.""" | ||
| # Get columns/keys using unified accessor | ||
| get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys())) | ||
| expected_type = lazy.pd.DataFrame if is_dataframe else dict | ||
| type_name = "DataFrame" if is_dataframe else "dict" | ||
|
|
||
| # Check required columns | ||
| missing = set(self.config.required_columns) - get_keys(data) | ||
|
|
@@ -96,6 +142,15 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. | |
| f"Custom generator function failed for column '{self.config.name}': {e}" | ||
| ) from e | ||
|
|
||
| return self._postprocess_result(result, is_dataframe, keys_before) | ||
|
|
||
| def _postprocess_result( | ||
| self, | ||
| result: dict | pd.DataFrame | list[dict], | ||
| is_dataframe: bool, | ||
| keys_before: set[str], | ||
| ) -> dict | pd.DataFrame | list[dict]: | ||
| """Validate type and output columns of a generation result.""" | ||
| # Cell-by-cell with allow_resize: accept dict or list[dict] | ||
| if not is_dataframe and self.config.allow_resize: | ||
| if isinstance(result, dict): | ||
|
|
@@ -113,6 +168,8 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd. | |
| ) | ||
|
|
||
| # Validate return type for non-resize paths | ||
| expected_type = lazy.pd.DataFrame if is_dataframe else dict | ||
| type_name = "DataFrame" if is_dataframe else "dict" | ||
| if not isinstance(result, expected_type): | ||
| raise CustomColumnGenerationError( | ||
| f"Custom generator for column '{self.config.name}' must return a {type_name}, " | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.