diff --git a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_embedding.py index 8a9c6f9bd12a..6ce388d7d072 100644 --- a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_text_embedding.py @@ -52,7 +52,9 @@ def __init__( self._org_id = org_id self._log = log if log is not None else NullLogger() - async def generate_embeddings_async(self, texts: List[str]) -> ndarray: + async def generate_embeddings_async( + self, texts: List[str], batch_size: Optional[int] = None + ) -> ndarray: model_args = {} if self._api_type in ["azure", "azure_ad"]: model_args["engine"] = self._model_id @@ -60,18 +62,21 @@ async def generate_embeddings_async(self, texts: List[str]) -> ndarray: model_args["model"] = self._model_id try: - response: Any = await openai.Embedding.acreate( - **model_args, - api_key=self._api_key, - api_type=self._api_type, - api_base=self._endpoint, - api_version=self._api_version, - organization=self._org_id, - input=texts, - ) - - # make numpy arrays from the response - raw_embeddings = [array(x["embedding"]) for x in response["data"]] + raw_embeddings = [] + batch_size = batch_size or len(texts) + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + response: Any = await openai.Embedding.acreate( + **model_args, + api_key=self._api_key, + api_type=self._api_type, + api_base=self._endpoint, + api_version=self._api_version, + organization=self._org_id, + input=batch, + ) + # make numpy arrays from the response + raw_embeddings.extend([array(x["embedding"]) for x in response["data"]]) return array(raw_embeddings) except Exception as ex: raise AIException( diff --git a/python/tests/integration/embeddings/test_azure_oai_embedding_service.py b/python/tests/integration/embeddings/test_azure_oai_embedding_service.py index 933bac7f598b..79851f65907a 100644 --- a/python/tests/integration/embeddings/test_azure_oai_embedding_service.py +++ b/python/tests/integration/embeddings/test_azure_oai_embedding_service.py @@ -33,3 +33,23 @@ async def test_azure_text_embedding_service(create_kernel, get_aoai_config): text="this is a test", external_source_name="external source", ) + + +@pytest.mark.asyncio +async def test_batch_azure_embeddings(get_aoai_config): + # Configure LLM service + _, api_key, endpoint = get_aoai_config + + if "Python_Integration_Tests" in os.environ: + deployment_name = os.environ["AzureOpenAIEmbeddings__DeploymentName"] + + else: + deployment_name = "ada-002" + + embeddings_service = sk_oai.AzureTextEmbedding(deployment_name, endpoint, api_key) + texts = ["hello world", "goodbye world"] + results = await embeddings_service.generate_embeddings_async(texts) + batch_results = await embeddings_service.generate_embeddings_async( + texts, batch_size=1 + ) + assert len(results) == len(batch_results) diff --git a/python/tests/unit/ai/open_ai/services/test_azure_text_embedding.py b/python/tests/unit/ai/open_ai/services/test_azure_text_embedding.py index a5b0558b6f65..af7b4b94c937 100644 --- a/python/tests/unit/ai/open_ai/services/test_azure_text_embedding.py +++ b/python/tests/unit/ai/open_ai/services/test_azure_text_embedding.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from logging import Logger -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, call, patch import pytest @@ -122,7 +122,7 @@ async def test_azure_text_embedding_calls_with_parameters() -> None: api_type = "azure" api_version = "2023-03-15-preview" logger = Logger("test_logger") - texts = ["hello world"] + texts = ["hello world", "goodbye world"] azure_text_embedding = AzureTextEmbedding( deployment_name=deployment_name, @@ -143,3 +143,55 @@ async def test_azure_text_embedding_calls_with_parameters() -> None: organization=None, input=texts, ) + + +@pytest.mark.asyncio +async def test_azure_text_embedding_calls_with_batches() -> None: + mock_openai = AsyncMock() + with patch( + "semantic_kernel.connectors.ai.open_ai.services.open_ai_text_embedding.openai", + new=mock_openai, + ): + deployment_name = "test_deployment" + endpoint = "https://test-endpoint.com" + api_key = "test_api_key" + api_type = "azure" + api_version = "2023-03-15-preview" + logger = Logger("test_logger") + texts = [i for i in range(0, 5)] + + azure_text_embedding = AzureTextEmbedding( + deployment_name=deployment_name, + endpoint=endpoint, + api_key=api_key, + api_version=api_version, + logger=logger, + ) + + await azure_text_embedding.generate_embeddings_async(texts, batch_size=3) + + mock_openai.assert_has_calls( + [ + call.Embedding.acreate( + engine=deployment_name, + api_key=api_key, + api_type=api_type, + api_base=endpoint, + api_version=api_version, + organization=None, + input=texts[0:3], + ), + call.Embedding.acreate().__getitem__("data"), + call.Embedding.acreate().__getitem__().__iter__(), + call.Embedding.acreate( + engine=deployment_name, + api_key=api_key, + api_type=api_type, + api_base=endpoint, + api_version=api_version, + organization=None, + input=texts[3:5], + ), + ], + any_order=False, + )