Skip to content

Commit a87f98b

Browse files
eric-tramel and claude
committed
feat: add async image generation (agenerate_image)
Add agenerate_image(), _agenerate_image_chat_completion(), and _agenerate_image_diffusion() async methods mirroring the sync generate_image() added in #317. The chat completion path uses acompletion(), the diffusion path uses router.aimage_generation(). Includes 5 new tests covering both paths, error cases, and usage tracking. Also fixes F821 lint errors for type annotations. Co-Authored-By: Remi <noreply@anthropic.com>
1 parent c0854f1 commit a87f98b

2 files changed

Lines changed: 272 additions & 0 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/models/facade.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,3 +729,169 @@ async def agenerate(
729729
)
730730

731731
return output_obj, messages
732+
733+
@acatch_llm_exceptions
async def agenerate_image(
    self,
    prompt: str,
    multi_modal_context: list[dict[str, Any]] | None = None,
    skip_usage_tracking: bool = False,
    **kwargs: Any,
) -> list[str]:
    """Asynchronously generate image(s), returned as base64-encoded strings.

    The API path is chosen automatically from the model name:

    - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) go through
      the image_generation API.
    - Every other model goes through the chat/completions API (default).

    Both paths yield base64 payloads; when the backend returns several
    images, all of them are included in the result.

    Args:
        prompt: The prompt for image generation
        multi_modal_context: Optional list of image contexts for multi-modal
            generation. Only used with autoregressive models via the chat
            completions API.
        skip_usage_tracking: Whether to skip usage tracking
        **kwargs: Additional arguments forwarded to the model (including
            n=number of images)

    Returns:
        List of base64-encoded image strings (without data URI prefix)

    Raises:
        ImageGenerationError: If image generation fails or returns invalid data
    """
    logger.debug(
        f"Generating image with model {self.model_name!r}...",
        extra={"model": self.model_name, "prompt": prompt},
    )

    # Route to whichever API matches the model family.
    if is_image_diffusion_model(self.model_name):
        encoded_images = await self._agenerate_image_diffusion(prompt, skip_usage_tracking, **kwargs)
    else:
        encoded_images = await self._agenerate_image_chat_completion(
            prompt, multi_modal_context, skip_usage_tracking, **kwargs
        )

    # Record how many images were produced, unless tracking is disabled.
    if not skip_usage_tracking and encoded_images:
        self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(encoded_images)))

    return encoded_images
async def _agenerate_image_chat_completion(
    self,
    prompt: str,
    multi_modal_context: list[dict[str, Any]] | None = None,
    skip_usage_tracking: bool = False,
    **kwargs: Any,
) -> list[str]:
    """Async version of _generate_image_chat_completion.

    Generate image(s) using an autoregressive model via the chat
    completions API.

    Args:
        prompt: The prompt for image generation
        multi_modal_context: Optional list of image contexts for multi-modal generation
        skip_usage_tracking: Whether to skip usage tracking
        **kwargs: Additional arguments to pass to the model

    Returns:
        List of base64-encoded image strings

    Raises:
        ImageGenerationError: If the response has no choices or contains no
            extractable image data
    """
    messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context)

    # The original `try: ... except Exception: raise` wrapper was a no-op
    # (no cleanup, no handling) and has been removed.
    response = await self.acompletion(
        messages=messages,
        skip_usage_tracking=skip_usage_tracking,
        **kwargs,
    )

    logger.debug(
        f"Received image(s) from autoregressive model {self.model_name!r}",
        extra={"model": self.model_name, "response": response},
    )

    # Validate response structure
    if not response.choices or len(response.choices) == 0:
        raise ImageGenerationError("Image generation response missing choices")

    message = response.choices[0].message
    images: list[str] = []

    # Extract base64 from the `images` attribute (primary path).
    if hasattr(message, "images") and message.images:
        for image in message.images:
            # Handle different response formats
            if isinstance(image, dict) and "image_url" in image:
                image_url = image["image_url"]

                if isinstance(image_url, dict) and "url" in image_url:
                    if (b64 := _try_extract_base64(image_url["url"])) is not None:
                        images.append(b64)
                elif isinstance(image_url, str):
                    if (b64 := _try_extract_base64(image_url)) is not None:
                        images.append(b64)
            # Fallback: treat as base64 string
            elif isinstance(image, str):
                if (b64 := _try_extract_base64(image)) is not None:
                    images.append(b64)

    # Fallback: check the content field if it looks like image data.
    if not images:
        content = message.content or ""
        if content and (content.startswith("data:image/") or is_base64_image(content)):
            if (b64 := _try_extract_base64(content)) is not None:
                images.append(b64)

    if not images:
        raise ImageGenerationError("No image data found in image generation response")

    return images
async def _agenerate_image_diffusion(
    self, prompt: str, skip_usage_tracking: bool = False, **kwargs: Any
) -> list[str]:
    """Async version of _generate_image_diffusion.

    Generate image(s) using a diffusion model via the image_generation API.

    Always returns base64. If the API returns URLs instead of inline base64,
    the images are downloaded and converted automatically.

    Args:
        prompt: The prompt for image generation
        skip_usage_tracking: Whether to skip usage tracking
        **kwargs: Additional arguments to pass to the model

    Returns:
        List of base64-encoded image strings

    Raises:
        ImageGenerationError: If the API returns no data or no image data
            could be extracted from the response
    """
    kwargs = self.consolidate_kwargs(**kwargs)

    response = None

    # try/finally (without the original's redundant `except Exception: raise`)
    # so token usage is tracked even when extraction below raises, as long as
    # a response object was actually received.
    try:
        response = await self._router.aimage_generation(prompt=prompt, model=self.model_name, **kwargs)

        logger.debug(
            f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}",
            extra={"model": self.model_name, "response": response},
        )

        # Validate response
        if not response.data or len(response.data) == 0:
            raise ImageGenerationError("Image generation returned no data")

        images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None]

        if not images:
            raise ImageGenerationError("No image data could be extracted from response")

        return images
    finally:
        if not skip_usage_tracking and response is not None:
            self._track_token_usage_from_image_diffusion(response)

packages/data-designer-engine/tests/engine/models/test_facade.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
if TYPE_CHECKING:
2020
import litellm
21+
from litellm.types.utils import EmbeddingResponse, ModelResponse
2122

2223

2324
def mock_oai_response_object(response_text: str) -> StubResponse:
@@ -1403,3 +1404,108 @@ async def test_agenerate_success(
14031404
# Trace should contain at least the user prompt and the assistant response
14041405
assert any(msg.role == "user" for msg in trace)
14051406
assert any(msg.role == "assistant" and msg.content == "parsed output" for msg in trace)
1407+
1408+
1409+
# =============================================================================
1410+
# Async image generation tests
1411+
# =============================================================================
1412+
1413+
1414+
@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_diffusion_success(
    mock_aimage_generation: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Diffusion-path async generation returns every base64 payload and tracks usage."""
    mock_aimage_generation.return_value = litellm.types.utils.ImageResponse(
        data=[
            litellm.types.utils.ImageObject(b64_json="image1_base64"),
            litellm.types.utils.ImageObject(b64_json="image2_base64"),
        ]
    )

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        result = await stub_model_facade.agenerate_image(prompt="test prompt")

    assert result == ["image1_base64", "image2_base64"]
    assert len(result) == 2
    assert mock_aimage_generation.call_count == 1
    # Both generated images should be reflected in usage stats.
    assert stub_model_facade.usage_stats.image_usage.total_images == 2
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_chat_completion_success(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """Chat-completion-path async generation extracts base64 from the images attribute."""
    assistant_message = litellm.types.utils.Message(
        role="assistant",
        content="",
        images=[
            litellm.types.utils.ImageURLListItem(
                type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0
            ),
        ],
    )
    mock_acompletion.return_value = litellm.types.utils.ModelResponse(
        choices=[litellm.types.utils.Choices(message=assistant_message)]
    )

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False):
        result = await stub_model_facade.agenerate_image(prompt="test prompt")

    assert result == ["image1"]
    assert len(result) == 1
    assert mock_acompletion.call_count == 1
    # The single generated image should be reflected in usage stats.
    assert stub_model_facade.usage_stats.image_usage.total_images == 1
@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_diffusion_no_data(
    mock_aimage_generation: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """An empty diffusion response raises ImageGenerationError."""
    mock_aimage_generation.return_value = litellm.types.utils.ImageResponse(data=[])

    with (
        patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True),
        pytest.raises(ImageGenerationError, match="Image generation returned no data"),
    ):
        await stub_model_facade.agenerate_image(prompt="test prompt")
@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_chat_completion_no_choices(
    mock_acompletion: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """A chat completion response without choices raises ImageGenerationError."""
    mock_acompletion.return_value = litellm.types.utils.ModelResponse(choices=[])

    with (
        patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False),
        pytest.raises(ImageGenerationError, match="Image generation response missing choices"),
    ):
        await stub_model_facade.agenerate_image(prompt="test prompt")
@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_agenerate_image_skip_usage_tracking(
    mock_aimage_generation: AsyncMock,
    stub_model_facade: ModelFacade,
) -> None:
    """With skip_usage_tracking=True, images are returned but no usage is recorded."""
    mock_aimage_generation.return_value = litellm.types.utils.ImageResponse(
        data=[litellm.types.utils.ImageObject(b64_json="image1_base64")]
    )

    with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True):
        result = await stub_model_facade.agenerate_image(prompt="test prompt", skip_usage_tracking=True)

    assert len(result) == 1
    # Usage counter must remain untouched when tracking is skipped.
    assert stub_model_facade.usage_stats.image_usage.total_images == 0

0 commit comments

Comments
 (0)