Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import functools
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any, Coroutine, TypeVar, overload

from data_designer.config.column_configs import GenerationStrategy
from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
from data_designer.logging import LOG_DOUBLE_INDENT, LOG_INDENT

_T = TypeVar("_T")

_SYNC_BRIDGE_TIMEOUT = 300

if TYPE_CHECKING:
import pandas as pd

Expand All @@ -23,33 +28,84 @@
logger = logging.getLogger(__name__)


def _run_coroutine_sync(coro: Coroutine[Any, Any, _T]) -> _T:
"""Run an async coroutine from sync context.

- No running event loop → ``asyncio.run(coro)``
- Running event loop (e.g. notebook/service) → run in a background thread
"""
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
future = pool.submit(asyncio.run, coro)
timed_out = False
try:
result = future.result(timeout=_SYNC_BRIDGE_TIMEOUT)
except concurrent.futures.TimeoutError as exc:
timed_out = True
logger.warning(f"⚠️ Sync bridge timed out after {_SYNC_BRIDGE_TIMEOUT}s; background thread still running")
raise TimeoutError(f"_run_coroutine_sync timed out after {_SYNC_BRIDGE_TIMEOUT}s") from exc
finally:
pool.shutdown(wait=not timed_out, cancel_futures=timed_out)
return result


class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC):
@property
def can_generate_from_scratch(self) -> bool:
    """Whether this generator can produce rows without any input data.

    This base implementation always returns ``False``; generators that can
    produce data from scratch override it.
    """
    return False

@property
def is_order_dependent(self) -> bool:
    """Whether this generator's output depends on prior row-group calls.

    When ``True``, row group N must complete before row group N+1 starts.
    Example: SeedDatasetColumnGenerator tracks its position in the seed
    dataset, so its output for a row group depends on earlier calls.
    The base implementation returns ``False`` (order-independent).
    """
    return False

def _is_overridden(self, method_name: str) -> bool:
    """Return True when this instance's class supplies its own version of
    ``method_name`` instead of inheriting the ColumnGenerator default."""
    subclass_attr = getattr(type(self), method_name)
    base_attr = getattr(ColumnGenerator, method_name)
    return subclass_attr is not base_attr

# Strategy contract: every concrete generator declares whether it produces
# values cell-by-cell or a full column at a time.
@staticmethod
@abstractmethod
def get_generation_strategy() -> GenerationStrategy: ...

@overload
@abstractmethod
def generate(self, data: dict) -> dict: ...

@overload
@abstractmethod
def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...

def generate(self, data: DataT) -> DataT:
    """Sync generate — overridden by most concrete generators.

    Default bridges to ``agenerate()`` for async-first subclasses that only
    implement ``agenerate()``.

    Args:
        data: A single record (dict) or a DataFrame, depending on the
            generator's strategy.

    Returns:
        The same container type as ``data`` with this column's value(s) added.

    Raises:
        NotImplementedError: If neither ``generate()`` nor ``agenerate()``
            is overridden by the subclass.
    """
    # NOTE: the concrete default must NOT be @abstractmethod — otherwise
    # async-only subclasses (implementing just agenerate) could never be
    # instantiated, defeating the purpose of this bridge.
    if not self._is_overridden("agenerate"):
        raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()")
    return _run_coroutine_sync(self.agenerate(data))

@overload
async def agenerate(self, data: dict) -> dict: ...

@overload
async def agenerate(self, data: pd.DataFrame) -> pd.DataFrame: ...

async def agenerate(self, data: DataT) -> DataT:
    """Async generate — delegates to sync ``generate()`` via thread pool.

    Subclasses with native async support (e.g. ColumnGeneratorWithModelChatCompletion)
    should override this with a direct async implementation.

    Args:
        data: A single record (dict) or a DataFrame, depending on the
            generator's strategy.

    Returns:
        The same container type as ``data`` with this column's value(s) added.

    Raises:
        NotImplementedError: If neither ``generate()`` nor ``agenerate()``
            is overridden by the subclass.
    """
    if not self._is_overridden("generate"):
        raise NotImplementedError(f"{type(self).__name__} must implement either generate() or agenerate()")
    # Pass a copy so the worker thread never mutates the caller's object
    # while the event loop keeps running.
    return await asyncio.to_thread(self.generate, data.copy())

def log_pre_generation(self) -> None:
"""A shared method to log info before the generator's `generate` method is called.
Expand All @@ -68,6 +124,10 @@ def can_generate_from_scratch(self) -> bool:
# Contract: produce a brand-new DataFrame of `num_records` rows with no
# input data required.
@abstractmethod
def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...

async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame:
    """Async counterpart of ``generate_from_scratch``.

    Offloads the synchronous implementation to a worker thread so the event
    loop is not blocked while records are produced.
    """
    produce = functools.partial(self.generate_from_scratch, num_records)
    return await asyncio.to_thread(produce)


class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC):
@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
import inspect
import logging
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -65,12 +66,57 @@ def generate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict

return self._generate(data, is_dataframe)

async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | list[dict]:
    """Async generate — branches on strategy and detects coroutine functions.

    FULL_COLUMN generation always runs the sync path in a worker thread.
    For cell-by-cell generation, an async user function is awaited natively;
    a sync user function falls back to the threaded sync path.

    Raises:
        CustomColumnGenerationError: If required columns are missing or the
            user's generator function raises.
    """
    is_full_column = self.config.generation_strategy == GenerationStrategy.FULL_COLUMN
    if is_full_column:
        return await asyncio.to_thread(self.generate, data.copy())
    # The @custom_column_generator decorator wraps the user function in a sync
    # wrapper, so we must unwrap to detect async functions.
    fn_unwrapped = inspect.unwrap(self.config.generator_function)
    if asyncio.iscoroutinefunction(fn_unwrapped):
        missing = set(self.config.required_columns) - set(data.keys())
        if missing:
            raise CustomColumnGenerationError(
                f"Missing required columns for custom generator '{self.config.name}': {sorted(missing)}"
            )
        keys_before = set(data.keys())

        try:
            result = await self._ainvoke_generator_function(data)
        except CustomColumnGenerationError:
            raise
        except Exception as e:
            logger.warning(
                f"⚠️ Custom generator function {self.config.generator_function.__name__!r} "
                f"failed for column '{self.config.name}'. This record will be skipped.\n{e}"
            )
            raise CustomColumnGenerationError(
                f"Custom generator function failed for column '{self.config.name}': {e}"
            ) from e

        return self._postprocess_result(result, is_dataframe=False, keys_before=keys_before)
    return await asyncio.to_thread(self.generate, data)

async def _ainvoke_generator_function(self, data: dict) -> dict | pd.DataFrame:
    """Await the user's async generator function with the right arity.

    The @custom_column_generator decorator's sync wrapper returns a coroutine
    when the original function is async, so awaiting the wrapper's return
    value drives the underlying user coroutine.
    """
    arity = len(self._get_validated_params())
    generator_fn = self.config.generator_function
    if arity == 1:
        call_args = (data,)
    elif arity == 2:
        call_args = (data, self.config.generator_params)
    else:
        # Three-parameter form also receives the models dict; only build it
        # when actually needed.
        call_args = (data, self.config.generator_params, self._build_models_dict())
    return await generator_fn(*call_args)

def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.DataFrame | list[dict]:
"""Unified generation logic for both strategies."""
# Get columns/keys using unified accessor
get_keys = (lambda d: set(d.columns)) if is_dataframe else (lambda d: set(d.keys()))
expected_type = lazy.pd.DataFrame if is_dataframe else dict
type_name = "DataFrame" if is_dataframe else "dict"

# Check required columns
missing = set(self.config.required_columns) - get_keys(data)
Expand All @@ -96,6 +142,15 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.
f"Custom generator function failed for column '{self.config.name}': {e}"
) from e

return self._postprocess_result(result, is_dataframe, keys_before)

def _postprocess_result(
self,
result: dict | pd.DataFrame | list[dict],
is_dataframe: bool,
keys_before: set[str],
) -> dict | pd.DataFrame | list[dict]:
"""Validate type and output columns of a generation result."""
# Cell-by-cell with allow_resize: accept dict or list[dict]
if not is_dataframe and self.config.allow_resize:
if isinstance(result, dict):
Expand All @@ -113,6 +168,8 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.
)

# Validate return type for non-resize paths
expected_type = lazy.pd.DataFrame if is_dataframe else dict
type_name = "DataFrame" if is_dataframe else "dict"
if not isinstance(result, expected_type):
raise CustomColumnGenerationError(
f"Custom generator for column '{self.config.name}' must return a {type_name}, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,19 @@ class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]):
# Embeddings are produced one record at a time (cell-by-cell strategy).
def get_generation_strategy() -> GenerationStrategy:
    return GenerationStrategy.CELL_BY_CELL

def _prepare_embedding_inputs(self, data: dict) -> list[str]:
    """Deserialize the record and parse the target column into input texts.

    Shared by ``generate`` and ``agenerate`` so both paths build identical
    embedding inputs.
    """
    deserialized_record = deserialize_json_values(data)
    return parse_list_string(deserialized_record[self.config.target_column])

def generate(self, data: dict) -> dict:
    """Embed the target column's texts and store the result on the record."""
    vectors = self.model.generate_text_embeddings(
        input_texts=self._prepare_embedding_inputs(data)
    )
    data[self.config.name] = EmbeddingGenerationResult(embeddings=vectors).model_dump(mode="json")
    return data

async def agenerate(self, data: dict) -> dict:
    """Async variant of ``generate`` backed by ``agenerate_text_embeddings``."""
    texts = self._prepare_embedding_inputs(data)
    vectors = await self.model.agenerate_text_embeddings(input_texts=texts)
    result = EmbeddingGenerationResult(embeddings=vectors)
    data[self.config.name] = result.model_dump(mode="json")
    return data
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from data_designer.config.column_configs import ImageColumnConfig
Expand Down Expand Up @@ -31,46 +32,42 @@ def media_storage(self) -> MediaStorage:
# Images are generated one record at a time (cell-by-cell strategy).
def get_generation_strategy() -> GenerationStrategy:
    return GenerationStrategy.CELL_BY_CELL

def _prepare_image_inputs(self, data: dict) -> tuple[str, list[dict] | None]:
    """Validate inputs and render prompt for image generation.

    Shared by ``generate`` and ``agenerate`` so both paths validate and
    render identically.

    Returns:
        A ``(prompt, multi_modal_context)`` pair; the context is whatever
        ``_build_multi_modal_context`` produces for the record.

    Raises:
        ValueError: If required columns are missing or the rendered prompt
            is empty.
    """
    deserialized_record = deserialize_json_values(data)

    # Validate required columns before attempting to render the template.
    missing_columns = list(set(self.config.required_columns) - set(data.keys()))
    if len(missing_columns) > 0:
        raise ValueError(
            f"There was an error preparing the Jinja2 expression template. "
            f"The following columns {missing_columns} are missing!"
        )

    # Render prompt template against the deserialized record.
    self.prepare_jinja2_template_renderer(self.config.prompt, list(deserialized_record.keys()))
    prompt = self.render_template(deserialized_record)

    # A blank prompt would silently produce a meaningless image — fail loudly.
    if not prompt or not prompt.strip():
        raise ValueError(f"Rendered prompt for column {self.config.name!r} is empty")

    multi_modal_context = self._build_multi_modal_context(deserialized_record)
    return prompt, multi_modal_context

def generate(self, data: dict) -> dict:
    """Generate image(s) and optionally save to disk."""
    prompt, multi_modal_context = self._prepare_image_inputs(data)
    base64_images = self.model.generate_image(prompt=prompt, multi_modal_context=multi_modal_context)
    # Persist each image via media storage (mode determines disk vs dataframe
    # storage); the column name is used as the subfolder to keep outputs
    # organized per column.
    saved = []
    for base64_image in base64_images:
        saved.append(self.media_storage.save_base64_image(base64_image, subfolder_name=self.config.name))
    data[self.config.name] = saved
    return data

async def agenerate(self, data: dict) -> dict:
    """Native async generate using model.agenerate_image."""
    prompt, multi_modal_context = self._prepare_image_inputs(data)
    base64_images = await self.model.agenerate_image(prompt=prompt, multi_modal_context=multi_modal_context)

    def _persist_all() -> list:
        return [
            self.media_storage.save_base64_image(img, subfolder_name=self.config.name)
            for img in base64_images
        ]

    # Disk writes are blocking; push them off the event loop in one batch.
    data[self.config.name] = await asyncio.to_thread(_persist_all)
    return data
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColu
# Seed datasets are materialized a whole column-group at a time.
def get_generation_strategy() -> GenerationStrategy:
    return GenerationStrategy.FULL_COLUMN

@property
def is_order_dependent(self) -> bool:
    """Seed rows are consumed sequentially from a tracked position, so row
    group N must complete before row group N+1 starts."""
    return True

@property
def num_records_sampled(self) -> int:
    """Read-only count of records sampled from the seed dataset so far."""
    return self._num_records_sampled
Expand Down
Loading
Loading