Skip to content

Commit

Permalink
Python: Embeddings batch size (#2331)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
Adds an optional `batch_size` parameter for cases where Azure OpenAI limits
the number of inputs allowed per embeddings request. Fixes
#2099
### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

---------

Co-authored-by: Abby Harrison <[email protected]>
Co-authored-by: Abby Harrison <[email protected]>
  • Loading branch information
3 people authored Aug 4, 2023
1 parent 77eeff2 commit 3774258
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,31 @@ def __init__(
self._org_id = org_id
self._log = log if log is not None else NullLogger()

async def generate_embeddings_async(self, texts: List[str]) -> ndarray:
async def generate_embeddings_async(
self, texts: List[str], batch_size: Optional[int] = None
) -> ndarray:
model_args = {}
if self._api_type in ["azure", "azure_ad"]:
model_args["engine"] = self._model_id
else:
model_args["model"] = self._model_id

try:
response: Any = await openai.Embedding.acreate(
**model_args,
api_key=self._api_key,
api_type=self._api_type,
api_base=self._endpoint,
api_version=self._api_version,
organization=self._org_id,
input=texts,
)

# make numpy arrays from the response
raw_embeddings = [array(x["embedding"]) for x in response["data"]]
raw_embeddings = []
batch_size = batch_size or len(texts)
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
response: Any = await openai.Embedding.acreate(
**model_args,
api_key=self._api_key,
api_type=self._api_type,
api_base=self._endpoint,
api_version=self._api_version,
organization=self._org_id,
input=batch,
)
# make numpy arrays from the response
raw_embeddings.extend([array(x["embedding"]) for x in response["data"]])
return array(raw_embeddings)
except Exception as ex:
raise AIException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,23 @@ async def test_azure_text_embedding_service(create_kernel, get_aoai_config):
text="this is a test",
external_source_name="external source",
)


@pytest.mark.asyncio
async def test_batch_azure_embeddings(get_aoai_config):
    """Integration test: batched and unbatched embedding generation agree.

    Generates embeddings for the same inputs once with a single request and
    once with ``batch_size=1`` (one request per input), and checks both paths
    produce one embedding per input text.
    """
    # Configure the Azure OpenAI service from the shared fixture.
    _, api_key, endpoint = get_aoai_config

    # CI supplies the deployment name via the environment; local runs fall
    # back to a default deployment name.
    if "Python_Integration_Tests" in os.environ:
        deployment_name = os.environ["AzureOpenAIEmbeddings__DeploymentName"]
    else:
        deployment_name = "ada-002"

    embeddings_service = sk_oai.AzureTextEmbedding(deployment_name, endpoint, api_key)
    texts = ["hello world", "goodbye world"]

    results = await embeddings_service.generate_embeddings_async(texts)
    batch_results = await embeddings_service.generate_embeddings_async(
        texts, batch_size=1
    )
    # Anchor both paths to ground truth (one embedding per input), not
    # merely to each other — otherwise both could be wrong in the same way.
    assert len(results) == len(texts)
    assert len(batch_results) == len(texts)
56 changes: 54 additions & 2 deletions python/tests/unit/ai/open_ai/services/test_azure_text_embedding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

from logging import Logger
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, call, patch

import pytest

Expand Down Expand Up @@ -122,7 +122,7 @@ async def test_azure_text_embedding_calls_with_parameters() -> None:
api_type = "azure"
api_version = "2023-03-15-preview"
logger = Logger("test_logger")
texts = ["hello world"]
texts = ["hello world", "goodbye world"]

azure_text_embedding = AzureTextEmbedding(
deployment_name=deployment_name,
Expand All @@ -143,3 +143,55 @@ async def test_azure_text_embedding_calls_with_parameters() -> None:
organization=None,
input=texts,
)


@pytest.mark.asyncio
async def test_azure_text_embedding_calls_with_batches() -> None:
    """Unit test: generate_embeddings_async splits inputs into batch-sized API calls.

    With 5 inputs and ``batch_size=3`` the service must issue exactly two
    ``openai.Embedding.acreate`` calls — for ``texts[0:3]`` and ``texts[3:5]``
    — in order.
    """
    mock_openai = AsyncMock()
    with patch(
        "semantic_kernel.connectors.ai.open_ai.services.open_ai_text_embedding.openai",
        new=mock_openai,
    ):
        deployment_name = "test_deployment"
        endpoint = "https://test-endpoint.com"
        api_key = "test_api_key"
        api_type = "azure"
        api_version = "2023-03-15-preview"
        logger = Logger("test_logger")
        # The service signature declares List[str]; use string inputs rather
        # than the ints a bare range() would give.
        texts = [f"text {i}" for i in range(5)]

        azure_text_embedding = AzureTextEmbedding(
            deployment_name=deployment_name,
            endpoint=endpoint,
            api_key=api_key,
            api_version=api_version,
            logger=logger,
        )

        await azure_text_embedding.generate_embeddings_async(texts, batch_size=3)

        # The __getitem__/__iter__ entries between the two acreate calls
        # mirror the service reading response["data"] after each request;
        # any_order=False pins the batch order.
        mock_openai.assert_has_calls(
            [
                call.Embedding.acreate(
                    engine=deployment_name,
                    api_key=api_key,
                    api_type=api_type,
                    api_base=endpoint,
                    api_version=api_version,
                    organization=None,
                    input=texts[0:3],
                ),
                call.Embedding.acreate().__getitem__("data"),
                call.Embedding.acreate().__getitem__().__iter__(),
                call.Embedding.acreate(
                    engine=deployment_name,
                    api_key=api_key,
                    api_type=api_type,
                    api_base=endpoint,
                    api_version=api_version,
                    organization=None,
                    input=texts[3:5],
                ),
            ],
            any_order=False,
        )

0 comments on commit 3774258

Please sign in to comment.