From 85a7ef27fb786315cd33542a8650536359dd6077 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Sat, 24 May 2025 16:24:45 +0530 Subject: [PATCH 01/52] feat: Add GoogleAITextEmbedder and GoogleAIDocumentEmbedder components --- .../embedders/google_ai/__init__.py | 3 + .../embedders/google_ai/google_embedder.py | 322 ++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py create mode 100644 integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py diff --git a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py b/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py new file mode 100644 index 000000000..bbdabff16 --- /dev/null +++ b/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py @@ -0,0 +1,3 @@ +from .google_embedder import GoogleAIDocumentEmbedder, GoogleAITextEmbedder + +__all__ = ["GoogleAIDocumentEmbedder", "GoogleAITextEmbedder"] \ No newline at end of file diff --git a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py b/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py new file mode 100644 index 000000000..901a47685 --- /dev/null +++ b/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os + +from typing import Any, Dict, List, Optional, Tuple +from more_itertools import batched +from tqdm import tqdm + +from google import genai +from google.genai import types + +from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + + +@component +class GoogleAITextEmbedder: + """ + Embeds strings using OpenAI models. + + You can use it to embed user query and send it to an embedding Retriever. + + ### Usage example + + ```python + from haystack.components.embedders import GoogleAITextEmbedder + + text_to_embed = "I love pizza!" + + text_embedder = GoogleAITextEmbedder() + + print(text_embedder.run(text_to_embed)) + + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + # 'meta': {'model': 'text-embedding-004-v2', + # 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} + ``` + """ + + def __init__( # pylint: disable=too-many-positional-arguments + self, + api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), + model: str = "text-embedding-004", + config: Optional[types.EmbedContentConfig] = types.EmbedContentConfig( + task_type="SEMANTIC_SIMILARITY"), + + ): + """ + Creates an GoogleAITextEmbedder component. + + :param api_key: + The Google API key. + You can set it with an environment variable `GOOGLE_API_KEY`, or pass with this parameter + during initialization. + :param model: + The name of the model to use for calculating embeddings. + The default model is `text-embedding-004`. + :param config: + A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. + For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). + """ + + self._api_key = api_key + self._model_name = model + self._config = config + self._client = genai.Client(api_key=api_key.resolve_value()) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model} + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self._model_name, + api_key=self._api_key.to_dict(), + config=self._config.to_json_dict() + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GoogleAITextEmbedder": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _prepare_input(self, text: str) -> Dict[str, Any]: + if not isinstance(text, str): + raise TypeError( + "GoogleAITextEmbedder expects a string as an input." + "In case you want to embed a list of Documents, please use the GoogleAIDocumentEmbedder." + ) + + text_to_embed = text + + kwargs: Dict[str, Any] = { + "model": self._model_name, "contents": text_to_embed} + if self._config: + kwargs["config"] = self._config + + return kwargs + + def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]: + return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}} + + @component.output_types(embedding=List[float], meta=Dict[str, Any]) + def run(self, text: str): + """ + Embeds a single string. + + :param text: + Text to embed. + + :returns: + A dictionary with the following keys: + - `embedding`: The embedding of the input text. + - `meta`: Information about the usage of the model. + """ + create_kwargs = self._prepare_input(text=text) + response = self._client.models.embed_content(**create_kwargs) + return self._prepare_output(result=response) + + +@component +class GoogleAIDocumentEmbedder: + """ + Computes document embeddings using OpenAI models. + + ### Usage example + + ```python + from haystack import Document + from haystack.components.embedders import GoogleAIDocumentEmbedder + + doc = Document(content="I love pizza!") + + document_embedder = GoogleAIDocumentEmbedder() + + result = document_embedder.run([doc]) + print(result['documents'][0].embedding) + + # [0.017020374536514282, -0.023255806416273117, ...] + ``` + """ + + def __init__( # pylint: disable=too-many-positional-arguments + self, + api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), + model: str = "text-embedding-004", + batch_size: int = 32, + progress_bar: bool = True, + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + config: Optional[types.EmbedContentConfig] = types.EmbedContentConfig( + task_type="SEMANTIC_SIMILARITY"), + ): + """ + Creates an GoogleAIDocumentEmbedder component. + + Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' + environment variables to override the `timeout` and `max_retries` parameters respectively + in the OpenAI client. + + :param api_key: + The OpenAI API key. + You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter + during initialization. + :param model: + The name of the model to use for calculating embeddings. + The default model is `text-embedding-ada-002`. + :param batch_size: + Number of documents to embed at once. + :param progress_bar: + If `True`, shows a progress bar when running. + :param meta_fields_to_embed: + List of metadata fields to embed along with the document text. + :param embedding_separator: + Separator used to concatenate the metadata fields to the document text. + :param config: + A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. + For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). + """ + self.api_key = api_key + self.model = model + self.batch_size = batch_size + self.progress_bar = progress_bar + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + self.client = genai.Client(api_key=api_key.resolve_value()) + self.config = config + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model} + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self.model, + batch_size=self.batch_size, + progress_bar=self.progress_bar, + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + api_key=self.api_key.to_dict(), + config=self.config.to_json_dict() + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIDocumentEmbedder": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = {} + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None + ] + + texts_to_embed[doc.id] = ( + self.embedding_separator.join( + meta_values_to_embed + [doc.content or ""]) + ) + + return texts_to_embed + + def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: + """ + Embed a list of texts in batches. + """ + + all_embeddings = [] + meta: Dict[str, Any] = {} + for batch in tqdm( + batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" + ): + args: Dict[str, Any] = {"model": self.model, + "contents": [b[1] for b in batch]} + if self.config: + args["config"] = self.config + + try: + response = self.client.models.embed_content(**args) + except Exception as exc: + ids = ", ".join(b[0] for b in batch) + msg = "Failed embedding of documents {ids} caused by {exc}" + logger.exception(msg, ids=ids, exc=exc) + continue + + embeddings = [el.values for el in response.embeddings] + all_embeddings.extend(embeddings) + + if "model" not in meta: + meta["model"] = self.model + + return all_embeddings, meta + + @component.output_types(documents=List[Document], meta=Dict[str, Any]) + def run(self, documents: List[Document]): + """ + Embeds a list of documents. + + :param documents: + A list of documents to embed. + + :returns: + A dictionary with the following keys: + - `documents`: A list of documents with embeddings. + - `meta`: Information about the usage of the model. + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + raise TypeError( + "GoogleAIDocumentEmbedder expects a list of Documents as input." + "In case you want to embed a string, please use the OpenAITextEmbedder." + ) + + texts_to_embed = self._prepare_texts_to_embed(documents=documents) + + embeddings, meta = self._embed_batch( + texts_to_embed=texts_to_embed, batch_size=self.batch_size) + + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + + return {"documents": documents, "meta": meta} From b9f94c762cbbbd17e41b59e3ce8fe99316f0ab38 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Sat, 24 May 2025 16:33:43 +0530 Subject: [PATCH 02/52] fix: Improve error messages for input type validation in GoogleAITextEmbedder and GoogleAIDocumentEmbedder --- .../components/embedders/google_ai/google_embedder.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py b/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py index 901a47685..9b7f43e33 100644 --- a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py +++ b/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py @@ -104,11 +104,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAITextEmbedder": def _prepare_input(self, text: str) -> Dict[str, Any]: if not isinstance(text, str): - raise TypeError( - "GoogleAITextEmbedder expects a string as an input." + error_message_text = ( + "GoogleAITextEmbedder expects a string as an input. " "In case you want to embed a list of Documents, please use the GoogleAIDocumentEmbedder." ) + raise TypeError(error_message_text) + text_to_embed = text kwargs: Dict[str, Any] = { @@ -306,10 +308,11 @@ def run(self, documents: List[Document]): - `meta`: Information about the usage of the model. """ if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): - raise TypeError( - "GoogleAIDocumentEmbedder expects a list of Documents as input." + error_message_documents = ( + "GoogleAIDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a string, please use the OpenAITextEmbedder." ) + raise TypeError(error_message_documents) texts_to_embed = self._prepare_texts_to_embed(documents=documents) From 682a4e210b3d094d42b8056524b0c2fc786b11f0 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Thu, 5 Jun 2025 10:17:39 +0530 Subject: [PATCH 03/52] feat: add Google GenAI embedder components for document and text embeddings --- .../embedders/google_ai/__init__.py | 3 - .../embedders/google_genai/__init__.py | 7 + .../google_genai/document_embedder.py} | 126 ---------------- .../embedders/google_genai/text_embedder.py | 137 ++++++++++++++++++ 4 files changed, 144 insertions(+), 129 deletions(-) delete mode 100644 integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py create mode 100644 integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py rename integrations/{google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py => google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py} (63%) create mode 100644 integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py diff --git a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py b/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py deleted file mode 100644 index bbdabff16..000000000 --- a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .google_embedder import GoogleAIDocumentEmbedder, GoogleAITextEmbedder - -__all__ = ["GoogleAIDocumentEmbedder", "GoogleAITextEmbedder"] \ No newline at end of file diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py new file mode 100644 index 000000000..3bebddbb5 --- /dev/null +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_embedder import GoogleAIDocumentEmbedder +from .text_embedder import GoogleAITextEmbedder + +__all__ = ["GoogleAIDocumentEmbedder", "GoogleAITextEmbedder"] diff --git a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py similarity index 63% rename from integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py rename to integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index 9b7f43e33..f2ebe507d 100644 --- a/integrations/google_ai/src/haystack_integrations/components/embedders/google_ai/google_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -2,8 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import os - from typing import Any, Dict, List, Optional, Tuple from more_itertools import batched from tqdm import tqdm @@ -17,130 +15,6 @@ logger = logging.getLogger(__name__) -@component -class GoogleAITextEmbedder: - """ - Embeds strings using OpenAI models. - - You can use it to embed user query and send it to an embedding Retriever. - - ### Usage example - - ```python - from haystack.components.embedders import GoogleAITextEmbedder - - text_to_embed = "I love pizza!" - - text_embedder = GoogleAITextEmbedder() - - print(text_embedder.run(text_to_embed)) - - # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], - # 'meta': {'model': 'text-embedding-004-v2', - # 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} - ``` - """ - - def __init__( # pylint: disable=too-many-positional-arguments - self, - api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), - model: str = "text-embedding-004", - config: Optional[types.EmbedContentConfig] = types.EmbedContentConfig( - task_type="SEMANTIC_SIMILARITY"), - - ): - """ - Creates an GoogleAITextEmbedder component. - - :param api_key: - The Google API key. - You can set it with an environment variable `GOOGLE_API_KEY`, or pass with this parameter - during initialization. - :param model: - The name of the model to use for calculating embeddings. - The default model is `text-embedding-004`. - :param config: - A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. - For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). - """ - - self._api_key = api_key - self._model_name = model - self._config = config - self._client = genai.Client(api_key=api_key.resolve_value()) - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model} - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes the component to a dictionary. - - :returns: - Dictionary with serialized data. - """ - return default_to_dict( - self, - model=self._model_name, - api_key=self._api_key.to_dict(), - config=self._config.to_json_dict() - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GoogleAITextEmbedder": - """ - Deserializes the component from a dictionary. - - :param data: - Dictionary to deserialize from. - :returns: - Deserialized component. - """ - deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) - return default_from_dict(cls, data) - - def _prepare_input(self, text: str) -> Dict[str, Any]: - if not isinstance(text, str): - error_message_text = ( - "GoogleAITextEmbedder expects a string as an input. " - "In case you want to embed a list of Documents, please use the GoogleAIDocumentEmbedder." - ) - - raise TypeError(error_message_text) - - text_to_embed = text - - kwargs: Dict[str, Any] = { - "model": self._model_name, "contents": text_to_embed} - if self._config: - kwargs["config"] = self._config - - return kwargs - - def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]: - return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}} - - @component.output_types(embedding=List[float], meta=Dict[str, Any]) - def run(self, text: str): - """ - Embeds a single string. - - :param text: - Text to embed. - - :returns: - A dictionary with the following keys: - - `embedding`: The embedding of the input text. - - `meta`: Information about the usage of the model. - """ - create_kwargs = self._prepare_input(text=text) - response = self._client.models.embed_content(**create_kwargs) - return self._prepare_output(result=response) - - @component class GoogleAIDocumentEmbedder: """ diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py new file mode 100644 index 000000000..c09e41b36 --- /dev/null +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +from google import genai +from google.genai import types + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + + +@component +class GoogleAITextEmbedder: + """ + Embeds strings using OpenAI models. + + You can use it to embed user query and send it to an embedding Retriever. + + ### Usage example + + ```python + from haystack.components.embedders import GoogleAITextEmbedder + + text_to_embed = "I love pizza!" + + text_embedder = GoogleAITextEmbedder() + + print(text_embedder.run(text_to_embed)) + + # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], + # 'meta': {'model': 'text-embedding-004-v2', + # 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}} + ``` + """ + + def __init__( # pylint: disable=too-many-positional-arguments + self, + api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), + model: str = "text-embedding-004", + config: Optional[types.EmbedContentConfig] = types.EmbedContentConfig( + task_type="SEMANTIC_SIMILARITY"), + + ): + """ + Creates an GoogleAITextEmbedder component. + + :param api_key: + The Google API key. + You can set it with an environment variable `GOOGLE_API_KEY`, or pass with this parameter + during initialization. + :param model: + The name of the model to use for calculating embeddings. + The default model is `text-embedding-004`. + :param config: + A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. + For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). + """ + + self._api_key = api_key + self._model_name = model + self._config = config + self._client = genai.Client(api_key=api_key.resolve_value()) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model} + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self._model_name, + api_key=self._api_key.to_dict(), + config=self._config.to_json_dict() + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GoogleAITextEmbedder": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _prepare_input(self, text: str) -> Dict[str, Any]: + if not isinstance(text, str): + error_message_text = ( + "GoogleAITextEmbedder expects a string as an input. " + "In case you want to embed a list of Documents, please use the GoogleAIDocumentEmbedder." + ) + + raise TypeError(error_message_text) + + text_to_embed = text + + kwargs: Dict[str, Any] = { + "model": self._model_name, "contents": text_to_embed} + if self._config: + kwargs["config"] = self._config + + return kwargs + + def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]: + return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}} + + @component.output_types(embedding=List[float], meta=Dict[str, Any]) + def run(self, text: str): + """ + Embeds a single string. + + :param text: + Text to embed. + + :returns: + A dictionary with the following keys: + - `embedding`: The embedding of the input text. + - `meta`: Information about the usage of the model. + """ + create_kwargs = self._prepare_input(text=text) + response = self._client.models.embed_content(**create_kwargs) + return self._prepare_output(result=response) From 778f702bd3937d9d49274eb9f6745cf158afc3c0 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Thu, 5 Jun 2025 12:57:10 +0530 Subject: [PATCH 04/52] feat: add unit tests for GoogleAIDocumentEmbedder and GoogleAITextEmbedder --- .../tests/test_document_embedder.py | 186 ++++++++++++++++++ .../google_genai/tests/test_text_embedder.py | 152 ++++++++++++++ 2 files changed, 338 insertions(+) create mode 100644 integrations/google_genai/tests/test_document_embedder.py create mode 100644 integrations/google_genai/tests/test_text_embedder.py diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py new file mode 100644 index 000000000..8017d8d4e --- /dev/null +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import random +from typing import List + +import pytest + +from haystack import Document +from haystack_integrations.components.embedders.google_genai import GoogleAIDocumentEmbedder +from haystack.utils.auth import Secret + +def mock_google_response(input: List[str], model: str = "text-embedding-004", **kwargs) -> dict: + dict_response = { + "embedding": [[random.random() for _ in range(768)] for _ in input], + "meta": { + "model": model + } + } + + return dict_response + + +class TestGoogleAIDocumentEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleAIDocumentEmbedder() + assert embedder.api_key.resolve_value() == "fake-api-key" + assert embedder.model == "text-embedding-004" + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + def test_init_with_parameters(self, monkeypatch): + embedder = GoogleAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key-2"), + model="model", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + assert embedder.api_key.resolve_value() == "fake-api-key-2" + assert embedder.model == "model" + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " + + def test_init_with_parameters_and_env_vars(self, monkeypatch): + embedder = GoogleAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key-2"), + model="model", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + assert embedder.api_key.resolve_value() == "fake-api-key-2" + assert embedder.model == "model" + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + GoogleAIDocumentEmbedder() + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + component = GoogleAIDocumentEmbedder() + data = component.to_dict() + assert data == { + 'type': 'haystack_integrations.components.embedders.google_genai.GoogleAIDocumentEmbedder', + 'init_parameters': { + 'model': 'text-embedding-004', + 'batch_size': 32, + 'progress_bar': True, + 'meta_fields_to_embed': [], + 'embedding_separator': '\n', + 'api_key': {'type': 'env_var', 'env_vars': ['GOOGLE_API_KEY'], 'strict': True}, + 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + } + } + + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "fake-api-key") + component = GoogleAIDocumentEmbedder( + api_key=Secret.from_env_var("ENV_VAR", strict=False), + model="model", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + data = component.to_dict() + assert data == { + 'type': 'haystack_integrations.components.embedders.google_genai.GoogleAIDocumentEmbedder', + 'init_parameters': { + 'model': 'model', + 'batch_size': 64, + 'progress_bar': False, + 'meta_fields_to_embed': ['test_field'], + 'embedding_separator': ' | ', + 'api_key': {'type': 'env_var', 'env_vars': ['ENV_VAR'], 'strict': False}, + 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + } + } + + def test_prepare_texts_to_embed_w_metadata(self): + documents = [ + Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={ + "meta_field": f"meta_value {i}"}) + for i in range(5) + ] + + embedder = GoogleAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | " + ) + + prepared_texts = embedder._prepare_texts_to_embed(documents) + + assert prepared_texts == { + "0": "meta_value 0 | document number 0:\ncontent", + "1": "meta_value 1 | document number 1:\ncontent", + "2": "meta_value 2 | document number 2:\ncontent", + "3": "meta_value 3 | document number 3:\ncontent", + "4": "meta_value 4 | document number 4:\ncontent", + } + + def test_run_wrong_input_format(self): + embedder = GoogleAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key")) + + # wrong formats + string_input = "text" + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="GoogleAIDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=string_input) + + with pytest.raises(TypeError, match="GoogleAIDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=list_integers_input) + + def test_run_on_empty_list(self): + embedder = GoogleAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key")) + + empty_list_input = [] + result = embedder.run(documents=empty_list_input) + + assert result["documents"] is not None + assert not result["documents"] # empty list + + @pytest.mark.skipif(os.environ.get("GOOGLE_API_KEY", "") == "", reason="GOOGLE_API_KEY is not set") + @pytest.mark.integration + def test_run(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={ + "topic": "ML"}), + ] + + model = "text-embedding-004" + + embedder = GoogleAIDocumentEmbedder(model=model, meta_fields_to_embed=[ + "topic"], embedding_separator=" | ") + + result = embedder.run(documents=docs) + documents_with_embeddings = result["documents"] + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 768 + assert all(isinstance(x, float) for x in doc.embedding) + + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + ) diff --git a/integrations/google_genai/tests/test_text_embedder.py b/integrations/google_genai/tests/test_text_embedder.py new file mode 100644 index 000000000..85a4981d4 --- /dev/null +++ b/integrations/google_genai/tests/test_text_embedder.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest + +from haystack_integrations.components.embedders.google_genai import GoogleAITextEmbedder +from haystack.utils.auth import Secret + +from google.genai import types +from google.genai.types import EmbedContentResponse, ContentEmbedding, EmbedContentConfig + + +class TestGoogleAITextEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleAITextEmbedder() + + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model_name == "text-embedding-004" + + def test_init_with_parameters(self): + embedder = GoogleAITextEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="model", + ) + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model_name == "model" + + def test_init_with_parameters_and_env_vars(self, monkeypatch): + embedder = GoogleAITextEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="model", + config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY") + ) + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model_name == "model" + assert embedder._config == types.EmbedContentConfig( + task_type="SEMANTIC_SIMILARITY") + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("GOOGLE_API_KEY", raising=False) + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + GoogleAITextEmbedder() + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + component = GoogleAITextEmbedder() + data = component.to_dict() + assert data == { + "type": "aystack_integrations.components.embedders.google_genai.GoogleAITextEmbedder", + "init_parameters": { + 'api_key': {'type': 'env_var', 'env_vars': ['GOOGLE_API_KEY'], 'strict': True}, + "model": "text-embedding-004", + 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + }, + } + + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "fake-api-key") + component = GoogleAITextEmbedder( + api_key=Secret.from_env_var("ENV_VAR", strict=False), + model="model", + config=types.EmbedContentConfig( + task_type="SEMANTIC_SIMILARITY" + ) + ) + data = component.to_dict() + assert data == { + 'type': 'aystack_integrations.components.embedders.google_genai.GoogleAITextEmbedder', + 'init_parameters': { + 'model': 'model', + 'api_key': { + 'type': 'env_var', + 'env_vars': ['ENV_VAR'], + 'strict': False + }, + 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + } + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + data = { + "type": "aystack_integrations.components.embedders.google_genai.GoogleAITextEmbedder", + "init_parameters": { + "api_key": {'type': 'env_var', 'env_vars': ['GOOGLE_API_KEY'], 'strict': True}, + "model": "text-embedding-004", + }, + } + component = GoogleAITextEmbedder.from_dict(data) + assert component._api_key.resolve_value() == "fake-api-key" + assert component._model_name == "text-embedding-004" + + def test_prepare_input(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + embedder = GoogleAITextEmbedder() + + contents = "The food was delicious" + prepared_input = embedder._prepare_input(contents) + assert prepared_input == { + "model": "text-embedding-004", + "contents": "The food was delicious", + "config": EmbedContentConfig( + http_options=None, + task_type='SEMANTIC_SIMILARITY', + title=None, + output_dimensionality=None, + mime_type=None, + auto_truncate=None + ) + } + + def test_prepare_output(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") + + response = EmbedContentResponse( + embeddings=[ContentEmbedding(values=[0.1, 0.2, 0.3])], + ) + + embedder = GoogleAITextEmbedder() + result = embedder._prepare_output(result=response) + assert result == { + "embedding": [0.1, 0.2, 0.3], + "meta": {"model": "text-embedding-004"}, + } + + def test_run_wrong_input_format(self): + embedder = GoogleAITextEmbedder( + api_key=Secret.from_token("fake-api-key")) + + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="GoogleAITextEmbedder expects a string as an input"): + embedder.run(text=list_integers_input) + + @pytest.mark.skipif(os.environ.get("GOOGLE_API_KEY", "") == "", reason="GOOGLE_API_KEY is not set") + @pytest.mark.integration + def test_run(self): + model = "text-embedding-004" + + embedder = GoogleAITextEmbedder(model=model) + result = embedder.run(text="The food was delicious") + + assert len(result["embedding"]) == 768 + assert all(isinstance(x, float) for x in result["embedding"]) + + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + ) From 3de6d1e34dbf301da266a05f9e0a193198b51aec Mon Sep 17 00:00:00 2001 From: garybadwal Date: Thu, 5 Jun 2025 14:49:51 +0530 Subject: [PATCH 05/52] refactor: clean up imports and improve list handling in GoogleAIDocumentEmbedder and GoogleAITextEmbedder tests --- .../google_genai/document_embedder.py | 10 ++-- .../embedders/google_genai/text_embedder.py | 1 - .../tests/test_document_embedder.py | 52 +++++++++++-------- .../google_genai/tests/test_text_embedder.py | 35 ++++++------- 4 files changed, 52 insertions(+), 46 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index f2ebe507d..58520a31b 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -3,14 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any, Dict, List, Optional, Tuple -from more_itertools import batched -from tqdm import tqdm from google import genai from google.genai import types - from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace +from more_itertools import batched +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -132,7 +131,8 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: texts_to_embed[doc.id] = ( self.embedding_separator.join( - meta_values_to_embed + [doc.content or ""]) + [*meta_values_to_embed, doc.content or ""] + ) ) return texts_to_embed @@ -181,7 +181,7 @@ def run(self, documents: List[Document]): - `documents`: A list of documents with embeddings. - `meta`: Information about the usage of the model. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): error_message_documents = ( "GoogleAIDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a string, please use the OpenAITextEmbedder." diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py index c09e41b36..0c0d9fbcf 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -6,7 +6,6 @@ from google import genai from google.genai import types - from haystack import component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index 8017d8d4e..534b4c9a3 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -7,14 +7,16 @@ from typing import List import pytest - from haystack import Document -from haystack_integrations.components.embedders.google_genai import GoogleAIDocumentEmbedder from haystack.utils.auth import Secret -def mock_google_response(input: List[str], model: str = "text-embedding-004", **kwargs) -> dict: +from haystack_integrations.components.embedders.google_genai import GoogleAIDocumentEmbedder + + +def mock_google_response(contents: List[str], model: str = "text-embedding-004", **kwargs) -> dict: + secure_random = random.SystemRandom() dict_response = { - "embedding": [[random.random() for _ in range(768)] for _ in input], + "embedding": [[secure_random.random() for _ in range(768)] for _ in contents], "meta": { "model": model } @@ -76,15 +78,18 @@ def test_to_dict(self, monkeypatch): component = GoogleAIDocumentEmbedder() data = component.to_dict() assert data == { - 'type': 'haystack_integrations.components.embedders.google_genai.GoogleAIDocumentEmbedder', - 'init_parameters': { - 'model': 'text-embedding-004', - 'batch_size': 32, - 'progress_bar': True, - 'meta_fields_to_embed': [], - 'embedding_separator': '\n', - 'api_key': {'type': 'env_var', 'env_vars': ['GOOGLE_API_KEY'], 'strict': True}, - 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + "type": ( + "haystack_integrations.components.embedders." + "google_genai.document_embedder.GoogleAIDocumentEmbedder" + ), + "init_parameters": { + "model": "text-embedding-004", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, + "config": {"task_type": "SEMANTIC_SIMILARITY"} } } @@ -100,15 +105,18 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): ) data = component.to_dict() assert data == { - 'type': 'haystack_integrations.components.embedders.google_genai.GoogleAIDocumentEmbedder', - 'init_parameters': { - 'model': 'model', - 'batch_size': 64, - 'progress_bar': False, - 'meta_fields_to_embed': ['test_field'], - 'embedding_separator': ' | ', - 'api_key': {'type': 'env_var', 'env_vars': ['ENV_VAR'], 'strict': False}, - 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + "type": ( + "haystack_integrations.components.embedders." + "google_genai.document_embedder.GoogleAIDocumentEmbedder" + ), + "init_parameters": { + "model": "model", + "batch_size": 64, + "progress_bar": False, + "meta_fields_to_embed": ["test_field"], + "embedding_separator": " | ", + "api_key": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "config": {"task_type": "SEMANTIC_SIMILARITY"} } } diff --git a/integrations/google_genai/tests/test_text_embedder.py b/integrations/google_genai/tests/test_text_embedder.py index 85a4981d4..7327d3901 100644 --- a/integrations/google_genai/tests/test_text_embedder.py +++ b/integrations/google_genai/tests/test_text_embedder.py @@ -5,12 +5,11 @@ import os import pytest - -from haystack_integrations.components.embedders.google_genai import GoogleAITextEmbedder +from google.genai import types +from google.genai.types import ContentEmbedding, EmbedContentConfig, EmbedContentResponse from haystack.utils.auth import Secret -from google.genai import types -from google.genai.types import EmbedContentResponse, ContentEmbedding, EmbedContentConfig +from haystack_integrations.components.embedders.google_genai import GoogleAITextEmbedder class TestGoogleAITextEmbedder: @@ -50,11 +49,11 @@ def test_to_dict(self, monkeypatch): component = GoogleAITextEmbedder() data = component.to_dict() assert data == { - "type": "aystack_integrations.components.embedders.google_genai.GoogleAITextEmbedder", + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleAITextEmbedder", "init_parameters": { - 'api_key': {'type': 'env_var', 'env_vars': ['GOOGLE_API_KEY'], 'strict': True}, + "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, "model": "text-embedding-004", - 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + "config": {"task_type": "SEMANTIC_SIMILARITY"} }, } @@ -69,24 +68,24 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): ) data = component.to_dict() assert data == { - 'type': 'aystack_integrations.components.embedders.google_genai.GoogleAITextEmbedder', - 'init_parameters': { - 'model': 'model', - 'api_key': { - 'type': 'env_var', - 'env_vars': ['ENV_VAR'], - 'strict': False + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleAITextEmbedder", + "init_parameters": { + "model": "model", + "api_key": { + "type": "env_var", + "env_vars": ["ENV_VAR"], + "strict": False }, - 'config': {'task_type': 'SEMANTIC_SIMILARITY'} + "config": {"task_type": "SEMANTIC_SIMILARITY"} } } def test_from_dict(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") data = { - "type": "aystack_integrations.components.embedders.google_genai.GoogleAITextEmbedder", + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleAITextEmbedder", "init_parameters": { - "api_key": {'type': 'env_var', 'env_vars': ['GOOGLE_API_KEY'], 'strict': True}, + "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, "model": "text-embedding-004", }, } @@ -105,7 +104,7 @@ def test_prepare_input(self, monkeypatch): "contents": "The food was delicious", "config": EmbedContentConfig( http_options=None, - task_type='SEMANTIC_SIMILARITY', + task_type="SEMANTIC_SIMILARITY", title=None, output_dimensionality=None, mime_type=None, From 9c6cb1a3b118178b542720ba6ca070144f0a39ec Mon Sep 17 00:00:00 2001 From: garybadwal Date: Thu, 5 Jun 2025 20:44:24 +0530 Subject: [PATCH 06/52] refactor: Rename classes and update imports for Google GenAI components --- .../embedders/google_genai/__init__.py | 6 +- .../google_genai/document_embedder.py | 64 +++++------ .../embedders/google_genai/text_embedder.py | 52 ++++----- .../tests/test_document_embedder.py | 91 +++++++++------- .../google_genai/tests/test_text_embedder.py | 100 ++++++++++-------- 5 files changed, 173 insertions(+), 140 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py index 3bebddbb5..f426cd628 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/__init__.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from .document_embedder import GoogleAIDocumentEmbedder -from .text_embedder import GoogleAITextEmbedder +from .document_embedder import GoogleGenAIDocumentEmbedder +from .text_embedder import GoogleGenAITextEmbedder -__all__ = ["GoogleAIDocumentEmbedder", "GoogleAITextEmbedder"] +__all__ = ["GoogleGenAIDocumentEmbedder", "GoogleGenAITextEmbedder"] diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index 58520a31b..07c0c16f7 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from google import genai from google.genai import types @@ -15,19 +15,19 @@ @component -class GoogleAIDocumentEmbedder: +class GoogleGenAIDocumentEmbedder: """ - Computes document embeddings using OpenAI models. + Computes document embeddings using Google AI models. ### Usage example ```python from haystack import Document - from haystack.components.embedders import GoogleAIDocumentEmbedder + from haystack_integrations.components.embedders import GoogleGenAIDocumentEmbedder doc = Document(content="I love pizza!") - document_embedder = GoogleAIDocumentEmbedder() + document_embedder = GoogleGenAIDocumentEmbedder() result = document_embedder.run([doc]) print(result['documents'][0].embedding) @@ -38,29 +38,35 @@ class GoogleAIDocumentEmbedder: def __init__( # pylint: disable=too-many-positional-arguments self, + *, api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), model: str = "text-embedding-004", + prefix: str = "", + suffix: str = "", batch_size: int = 32, progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", - config: Optional[types.EmbedContentConfig] = types.EmbedContentConfig( - task_type="SEMANTIC_SIMILARITY"), + config: Optional[Dict[str, Any]] = None, ): """ - Creates an GoogleAIDocumentEmbedder component. + Creates an GoogleGenAIDocumentEmbedder component. - Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' + Before initializing the component, you can set the 'GoogleGenAI_TIMEOUT' and 'GoogleGenAI_MAX_RETRIES' environment variables to override the `timeout` and `max_retries` parameters respectively - in the OpenAI client. + in the GoogleGenAI client. :param api_key: - The OpenAI API key. - You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter + The Google API key. + You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter during initialization. :param model: The name of the model to use for calculating embeddings. The default model is `text-embedding-ada-002`. + :param prefix: + A string to add at the beginning of each text. + :param suffix: + A string to add at the end of each text. :param batch_size: Number of documents to embed at once. :param progress_bar: @@ -75,18 +81,14 @@ def __init__( # pylint: disable=too-many-positional-arguments """ self.api_key = api_key self.model = model + self.prefix = prefix + self.suffix = suffix self.batch_size = batch_size self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator self.client = genai.Client(api_key=api_key.resolve_value()) - self.config = config - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model} + self.config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} def to_dict(self) -> Dict[str, Any]: """ @@ -98,16 +100,18 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model=self.model, + prefix=self.prefix, + suffix=self.suffix, batch_size=self.batch_size, progress_bar=self.progress_bar, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, api_key=self.api_key.to_dict(), - config=self.config.to_json_dict() + config=self.config, ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIDocumentEmbedder": + def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAIDocumentEmbedder": """ Deserializes the component from a dictionary. @@ -130,9 +134,7 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: ] texts_to_embed[doc.id] = ( - self.embedding_separator.join( - [*meta_values_to_embed, doc.content or ""] - ) + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix ) return texts_to_embed @@ -147,10 +149,9 @@ def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple for batch in tqdm( batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): - args: Dict[str, Any] = {"model": self.model, - "contents": [b[1] for b in batch]} + args: Dict[str, Any] = {"model": self.model, "contents": [b[1] for b in batch]} if self.config: - args["config"] = self.config + args["config"] = types.EmbedContentConfig(**self.config) if self.config else None try: response = self.client.models.embed_content(**args) @@ -169,7 +170,7 @@ def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple return all_embeddings, meta @component.output_types(documents=List[Document], meta=Dict[str, Any]) - def run(self, documents: List[Document]): + def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]: """ Embeds a list of documents. @@ -183,15 +184,14 @@ def run(self, documents: List[Document]): """ if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): error_message_documents = ( - "GoogleAIDocumentEmbedder expects a list of Documents as input. " - "In case you want to embed a string, please use the OpenAITextEmbedder." + "GoogleGenAIDocumentEmbedder expects a list of Documents as input. " + "In case you want to embed a string, please use the GoogleGenAITextEmbedder." ) raise TypeError(error_message_documents) texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, meta = self._embed_batch( - texts_to_embed=texts_to_embed, batch_size=self.batch_size) + embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) for doc, emb in zip(documents, embeddings): doc.embedding = emb diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py index 0c0d9fbcf..7f8608c1e 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from google import genai from google.genai import types @@ -13,20 +13,20 @@ @component -class GoogleAITextEmbedder: +class GoogleGenAITextEmbedder: """ - Embeds strings using OpenAI models. + Embeds strings using Google AI models. You can use it to embed user query and send it to an embedding Retriever. ### Usage example ```python - from haystack.components.embedders import GoogleAITextEmbedder + from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder text_to_embed = "I love pizza!" - text_embedder = GoogleAITextEmbedder() + text_embedder = GoogleGenAITextEmbedder() print(text_embedder.run(text_to_embed)) @@ -38,22 +38,27 @@ class GoogleAITextEmbedder: def __init__( # pylint: disable=too-many-positional-arguments self, + *, api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), model: str = "text-embedding-004", - config: Optional[types.EmbedContentConfig] = types.EmbedContentConfig( - task_type="SEMANTIC_SIMILARITY"), - + prefix: str = "", + suffix: str = "", + config: Optional[Dict[str, Any]] = None, ): """ - Creates an GoogleAITextEmbedder component. + Creates an GoogleGenAITextEmbedder component. :param api_key: The Google API key. - You can set it with an environment variable `GOOGLE_API_KEY`, or pass with this parameter + You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter during initialization. :param model: The name of the model to use for calculating embeddings. The default model is `text-embedding-004`. + :param prefix: + A string to add at the beginning of each text to embed. + :param suffix: + A string to add at the end of each text to embed. :param config: A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). @@ -61,15 +66,11 @@ def __init__( # pylint: disable=too-many-positional-arguments self._api_key = api_key self._model_name = model - self._config = config + self._prefix = prefix + self._suffix = suffix + self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} self._client = genai.Client(api_key=api_key.resolve_value()) - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model} - def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -81,11 +82,13 @@ def to_dict(self) -> Dict[str, Any]: self, model=self._model_name, api_key=self._api_key.to_dict(), - config=self._config.to_json_dict() + prefix=self._prefix, + suffix=self._suffix, + config=self._config, ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GoogleAITextEmbedder": + def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAITextEmbedder": """ Deserializes the component from a dictionary. @@ -100,18 +103,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAITextEmbedder": def _prepare_input(self, text: str) -> Dict[str, Any]: if not isinstance(text, str): error_message_text = ( - "GoogleAITextEmbedder expects a string as an input. " + "GoogleGenAITextEmbedder expects a string as an input. " "In case you want to embed a list of Documents, please use the GoogleAIDocumentEmbedder." ) raise TypeError(error_message_text) - text_to_embed = text + text_to_embed = self._prefix + text + self._suffix - kwargs: Dict[str, Any] = { - "model": self._model_name, "contents": text_to_embed} + kwargs: Dict[str, Any] = {"model": self._model_name, "contents": text_to_embed} if self._config: - kwargs["config"] = self._config + kwargs["config"] = types.EmbedContentConfig(**self._config) return kwargs @@ -119,7 +121,7 @@ def _prepare_output(self, result: types.EmbedContentResponse) -> Dict[str, Any]: return {"embedding": result.embeddings[0].values, "meta": {"model": self._model_name}} @component.output_types(embedding=List[float], meta=Dict[str, Any]) - def run(self, text: str): + def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]: """ Embeds a single string. diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index 534b4c9a3..eedf2f13a 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -10,129 +10,147 @@ from haystack import Document from haystack.utils.auth import Secret -from haystack_integrations.components.embedders.google_genai import GoogleAIDocumentEmbedder +from haystack_integrations.components.embedders.google_genai import GoogleGenAIDocumentEmbedder def mock_google_response(contents: List[str], model: str = "text-embedding-004", **kwargs) -> dict: secure_random = random.SystemRandom() dict_response = { "embedding": [[secure_random.random() for _ in range(768)] for _ in contents], - "meta": { - "model": model - } + "meta": {"model": model}, } return dict_response -class TestGoogleAIDocumentEmbedder: +class TestGoogleGenAIDocumentEmbedder: def test_init_default(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") - embedder = GoogleAIDocumentEmbedder() + embedder = GoogleGenAIDocumentEmbedder() assert embedder.api_key.resolve_value() == "fake-api-key" assert embedder.model == "text-embedding-004" + assert embedder.prefix == "" + assert embedder.suffix == "" assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.config == {"task_type": "SEMANTIC_SIMILARITY"} def test_init_with_parameters(self, monkeypatch): - embedder = GoogleAIDocumentEmbedder( + embedder = GoogleGenAIDocumentEmbedder( api_key=Secret.from_token("fake-api-key-2"), model="model", + prefix="prefix", + suffix="suffix", batch_size=64, progress_bar=False, meta_fields_to_embed=["test_field"], embedding_separator=" | ", + config={"task_type": "CLASSIFICATION"}, ) assert embedder.api_key.resolve_value() == "fake-api-key-2" assert embedder.model == "model" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " + assert embedder.config == {"task_type": "CLASSIFICATION"} def test_init_with_parameters_and_env_vars(self, monkeypatch): - embedder = GoogleAIDocumentEmbedder( + embedder = GoogleGenAIDocumentEmbedder( api_key=Secret.from_token("fake-api-key-2"), model="model", + prefix="prefix", + suffix="suffix", batch_size=64, progress_bar=False, meta_fields_to_embed=["test_field"], embedding_separator=" | ", + config={"task_type": "CLASSIFICATION"}, ) assert embedder.api_key.resolve_value() == "fake-api-key-2" assert embedder.model == "model" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " + assert embedder.config == {"task_type": "CLASSIFICATION"} def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("GOOGLE_API_KEY", raising=False) with pytest.raises(ValueError, match="None of the .* environment variables are set"): - GoogleAIDocumentEmbedder() + GoogleGenAIDocumentEmbedder() def test_to_dict(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") - component = GoogleAIDocumentEmbedder() + component = GoogleGenAIDocumentEmbedder() data = component.to_dict() assert data == { "type": ( - "haystack_integrations.components.embedders." - "google_genai.document_embedder.GoogleAIDocumentEmbedder" + "haystack_integrations.components.embedders" + ".google_genai.document_embedder.GoogleGenAIDocumentEmbedder" ), "init_parameters": { "model": "text-embedding-004", + "prefix": "", + "suffix": "", "batch_size": 32, "progress_bar": True, "meta_fields_to_embed": [], "embedding_separator": "\n", "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, - "config": {"task_type": "SEMANTIC_SIMILARITY"} - } + "config": {"task_type": "SEMANTIC_SIMILARITY"}, + }, } def test_to_dict_with_custom_init_parameters(self, monkeypatch): monkeypatch.setenv("ENV_VAR", "fake-api-key") - component = GoogleAIDocumentEmbedder( + component = GoogleGenAIDocumentEmbedder( api_key=Secret.from_env_var("ENV_VAR", strict=False), model="model", + prefix="prefix", + suffix="suffix", batch_size=64, progress_bar=False, meta_fields_to_embed=["test_field"], embedding_separator=" | ", + config={"task_type": "CLASSIFICATION"}, ) data = component.to_dict() assert data == { "type": ( - "haystack_integrations.components.embedders." - "google_genai.document_embedder.GoogleAIDocumentEmbedder" + "haystack_integrations.components.embedders" + ".google_genai.document_embedder.GoogleGenAIDocumentEmbedder" ), "init_parameters": { "model": "model", + "prefix": "prefix", + "suffix": "suffix", "batch_size": 64, "progress_bar": False, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", "api_key": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, - "config": {"task_type": "SEMANTIC_SIMILARITY"} - } + "config": {"task_type": "CLASSIFICATION"}, + }, } def test_prepare_texts_to_embed_w_metadata(self): documents = [ - Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={ - "meta_field": f"meta_value {i}"}) + Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) ] - embedder = GoogleAIDocumentEmbedder( + embedder = GoogleGenAIDocumentEmbedder( api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | " ) prepared_texts = embedder._prepare_texts_to_embed(documents) - assert prepared_texts == { "0": "meta_value 0 | document number 0:\ncontent", "1": "meta_value 1 | document number 1:\ncontent", @@ -142,22 +160,20 @@ def test_prepare_texts_to_embed_w_metadata(self): } def test_run_wrong_input_format(self): - embedder = GoogleAIDocumentEmbedder( - api_key=Secret.from_token("fake-api-key")) + embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) # wrong formats string_input = "text" list_integers_input = [1, 2, 3] - with pytest.raises(TypeError, match="GoogleAIDocumentEmbedder expects a list of Documents as input"): + with pytest.raises(TypeError, match="GoogleGenAIDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=string_input) - with pytest.raises(TypeError, match="GoogleAIDocumentEmbedder expects a list of Documents as input"): + with pytest.raises(TypeError, match="GoogleGenAIDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) def test_run_on_empty_list(self): - embedder = GoogleAIDocumentEmbedder( - api_key=Secret.from_token("fake-api-key")) + embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) empty_list_input = [] result = embedder.run(documents=empty_list_input) @@ -165,19 +181,20 @@ def test_run_on_empty_list(self): assert result["documents"] is not None assert not result["documents"] # empty list - @pytest.mark.skipif(os.environ.get("GOOGLE_API_KEY", "") == "", reason="GOOGLE_API_KEY is not set") + @pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", + ) @pytest.mark.integration def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document(content="A transformer is a deep learning architecture", meta={ - "topic": "ML"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] model = "text-embedding-004" - embedder = GoogleAIDocumentEmbedder(model=model, meta_fields_to_embed=[ - "topic"], embedding_separator=" | ") + embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ") result = embedder.run(documents=docs) documents_with_embeddings = result["documents"] @@ -189,6 +206,6 @@ def test_run(self): assert len(doc.embedding) == 768 assert all(isinstance(x, float) for x in doc.embedding) - assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( - "The model name does not contain 'text' and '004'" - ) + assert ( + "text" in result["meta"]["model"] and "004" in result["meta"]["model"] + ), "The model name does not contain 'text' and '004'" diff --git a/integrations/google_genai/tests/test_text_embedder.py b/integrations/google_genai/tests/test_text_embedder.py index 7327d3901..e5af84f66 100644 --- a/integrations/google_genai/tests/test_text_embedder.py +++ b/integrations/google_genai/tests/test_text_embedder.py @@ -5,97 +5,109 @@ import os import pytest -from google.genai import types from google.genai.types import ContentEmbedding, EmbedContentConfig, EmbedContentResponse from haystack.utils.auth import Secret -from haystack_integrations.components.embedders.google_genai import GoogleAITextEmbedder +from haystack_integrations.components.embedders.google_genai import GoogleGenAITextEmbedder -class TestGoogleAITextEmbedder: +class TestGoogleGenAITextEmbedder: def test_init_default(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") - embedder = GoogleAITextEmbedder() + embedder = GoogleGenAITextEmbedder() assert embedder._api_key.resolve_value() == "fake-api-key" assert embedder._model_name == "text-embedding-004" + assert embedder._prefix == "" + assert embedder._suffix == "" + assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"} def test_init_with_parameters(self): - embedder = GoogleAITextEmbedder( + embedder = GoogleGenAITextEmbedder( api_key=Secret.from_token("fake-api-key"), model="model", + prefix="prefix", + suffix="suffix", + config={"task_type": "CLASSIFICATION"}, ) assert embedder._api_key.resolve_value() == "fake-api-key" assert embedder._model_name == "model" + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._config == {"task_type": "CLASSIFICATION"} def test_init_with_parameters_and_env_vars(self, monkeypatch): - embedder = GoogleAITextEmbedder( + embedder = GoogleGenAITextEmbedder( api_key=Secret.from_token("fake-api-key"), model="model", - config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY") + prefix="prefix", + suffix="suffix", + config={"task_type": "CLASSIFICATION"}, ) assert embedder._api_key.resolve_value() == "fake-api-key" assert embedder._model_name == "model" - assert embedder._config == types.EmbedContentConfig( - task_type="SEMANTIC_SIMILARITY") - - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("GOOGLE_API_KEY", raising=False) - with pytest.raises(ValueError, match="None of the .* environment variables are set"): - GoogleAITextEmbedder() + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._config == {"task_type": "CLASSIFICATION"} def test_to_dict(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") - component = GoogleAITextEmbedder() + component = GoogleGenAITextEmbedder() data = component.to_dict() assert data == { - "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleAITextEmbedder", + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleGenAITextEmbedder", "init_parameters": { "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, "model": "text-embedding-004", - "config": {"task_type": "SEMANTIC_SIMILARITY"} + "prefix": "", + "suffix": "", + "config": {"task_type": "SEMANTIC_SIMILARITY"}, }, } def test_to_dict_with_custom_init_parameters(self, monkeypatch): monkeypatch.setenv("ENV_VAR", "fake-api-key") - component = GoogleAITextEmbedder( + component = GoogleGenAITextEmbedder( api_key=Secret.from_env_var("ENV_VAR", strict=False), model="model", - config=types.EmbedContentConfig( - task_type="SEMANTIC_SIMILARITY" - ) + prefix="prefix", + suffix="suffix", + config={"task_type": "CLASSIFICATION"}, ) data = component.to_dict() assert data == { - "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleAITextEmbedder", + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleGenAITextEmbedder", "init_parameters": { "model": "model", - "api_key": { - "type": "env_var", - "env_vars": ["ENV_VAR"], - "strict": False - }, - "config": {"task_type": "SEMANTIC_SIMILARITY"} - } + "api_key": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "prefix": "prefix", + "suffix": "suffix", + "config": {"task_type": "CLASSIFICATION"}, + }, } def test_from_dict(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") data = { - "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleAITextEmbedder", + "type": "haystack_integrations.components.embedders.google_genai.text_embedder.GoogleGenAITextEmbedder", "init_parameters": { "api_key": {"type": "env_var", "env_vars": ["GOOGLE_API_KEY"], "strict": True}, "model": "text-embedding-004", + "prefix": "", + "suffix": "", + "config": {"task_type": "CLASSIFICATION"}, }, } - component = GoogleAITextEmbedder.from_dict(data) + component = GoogleGenAITextEmbedder.from_dict(data) assert component._api_key.resolve_value() == "fake-api-key" assert component._model_name == "text-embedding-004" + assert component._prefix == "" + assert component._suffix == "" + assert component._config == {"task_type": "CLASSIFICATION"} def test_prepare_input(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") - embedder = GoogleAITextEmbedder() + embedder = GoogleGenAITextEmbedder() contents = "The food was delicious" prepared_input = embedder._prepare_input(contents) @@ -108,8 +120,8 @@ def test_prepare_input(self, monkeypatch): title=None, output_dimensionality=None, mime_type=None, - auto_truncate=None - ) + auto_truncate=None, + ), } def test_prepare_output(self, monkeypatch): @@ -119,7 +131,7 @@ def test_prepare_output(self, monkeypatch): embeddings=[ContentEmbedding(values=[0.1, 0.2, 0.3])], ) - embedder = GoogleAITextEmbedder() + embedder = GoogleGenAITextEmbedder() result = embedder._prepare_output(result=response) assert result == { "embedding": [0.1, 0.2, 0.3], @@ -127,25 +139,27 @@ def test_prepare_output(self, monkeypatch): } def test_run_wrong_input_format(self): - embedder = GoogleAITextEmbedder( - api_key=Secret.from_token("fake-api-key")) + embedder = GoogleGenAITextEmbedder(api_key=Secret.from_token("fake-api-key")) list_integers_input = [1, 2, 3] - with pytest.raises(TypeError, match="GoogleAITextEmbedder expects a string as an input"): + with pytest.raises(TypeError, match="GoogleGenAITextEmbedder expects a string as an input"): embedder.run(text=list_integers_input) - @pytest.mark.skipif(os.environ.get("GOOGLE_API_KEY", "") == "", reason="GOOGLE_API_KEY is not set") + @pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", + ) @pytest.mark.integration def test_run(self): model = "text-embedding-004" - embedder = GoogleAITextEmbedder(model=model) + embedder = GoogleGenAITextEmbedder(model=model) result = embedder.run(text="The food was delicious") assert len(result["embedding"]) == 768 assert all(isinstance(x, float) for x in result["embedding"]) - assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( - "The model name does not contain 'text' and '004'" - ) + assert ( + "text" in result["meta"]["model"] and "004" in result["meta"]["model"] + ), "The model name does not contain 'text' and '004'" From 89bb3becf0a57ef5ad742036c98dc7ebda0f467b Mon Sep 17 00:00:00 2001 From: garybadwal Date: Thu, 5 Jun 2025 20:47:23 +0530 Subject: [PATCH 07/52] feat: Add additional modules for Google GenAI embedders in config --- integrations/google_genai/pydoc/config.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/google_genai/pydoc/config.yml b/integrations/google_genai/pydoc/config.yml index e87f53cd0..095a67c07 100644 --- a/integrations/google_genai/pydoc/config.yml +++ b/integrations/google_genai/pydoc/config.yml @@ -3,6 +3,8 @@ loaders: search_path: [../src] modules: [ "haystack_integrations.components.generators.google_genai.chat.chat_generator", + "haystack_integrations.components.embedders.google_genai.document_embedder", + "haystack_integrations.components.embedders.google_genai.text_embedder" ] ignore_when_discovered: ["__init__"] processors: From f2d2a0c1929effd8aaace4efb4de87588d65a4f9 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 6 Jun 2025 18:16:33 +0530 Subject: [PATCH 08/52] chore: add 'more-itertools' to lint environment dependencies --- integrations/google_genai/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_genai/pyproject.toml b/integrations/google_genai/pyproject.toml index 41021b5ff..e3c94cc60 100644 --- a/integrations/google_genai/pyproject.toml +++ b/integrations/google_genai/pyproject.toml @@ -74,7 +74,7 @@ types = "mypy --install-types --non-interactive --explicit-package-bases {args:s [tool.hatch.envs.lint] installer = "uv" detached = true -dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "more-itertools"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" From 171ab378b4336d04c6c784a2b52bfb4889a528bd Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 6 Jun 2025 18:27:59 +0530 Subject: [PATCH 09/52] refactor: update GoogleGenAIDocumentEmbedder and GoogleGenAITextEmbedder to use private attributes for initialization --- .../google_genai/document_embedder.py | 78 +++++++++---------- .../embedders/google_genai/text_embedder.py | 3 +- .../tests/test_document_embedder.py | 54 ++++++------- 3 files changed, 64 insertions(+), 71 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index 07c0c16f7..ff457891f 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -36,7 +36,7 @@ class GoogleGenAIDocumentEmbedder: ``` """ - def __init__( # pylint: disable=too-many-positional-arguments + def __init__( self, *, api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), @@ -52,10 +52,6 @@ def __init__( # pylint: disable=too-many-positional-arguments """ Creates an GoogleGenAIDocumentEmbedder component. - Before initializing the component, you can set the 'GoogleGenAI_TIMEOUT' and 'GoogleGenAI_MAX_RETRIES' - environment variables to override the `timeout` and `max_retries` parameters respectively - in the GoogleGenAI client. - :param api_key: The Google API key. You can set it with the environment variable `GOOGLE_API_KEY`, or pass it via this parameter @@ -77,18 +73,19 @@ def __init__( # pylint: disable=too-many-positional-arguments Separator used to concatenate the metadata fields to the document text. :param config: A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. + If not specified, it defaults to {"task_type": "SEMANTIC_SIMILARITY"}. For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). """ - self.api_key = api_key - self.model = model - self.prefix = prefix - self.suffix = suffix - self.batch_size = batch_size - self.progress_bar = progress_bar - self.meta_fields_to_embed = meta_fields_to_embed or [] - self.embedding_separator = embedding_separator - self.client = genai.Client(api_key=api_key.resolve_value()) - self.config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} + self._api_key = api_key + self._model = model + self._prefix = prefix + self._suffix = suffix + self._batch_size = batch_size + self._progress_bar = progress_bar + self._meta_fields_to_embed = meta_fields_to_embed or [] + self._embedding_separator = embedding_separator + self._client = genai.Client(api_key=api_key.resolve_value()) + self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} def to_dict(self) -> Dict[str, Any]: """ @@ -99,15 +96,15 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - model=self.model, - prefix=self.prefix, - suffix=self.suffix, - batch_size=self.batch_size, - progress_bar=self.progress_bar, - meta_fields_to_embed=self.meta_fields_to_embed, - embedding_separator=self.embedding_separator, - api_key=self.api_key.to_dict(), - config=self.config, + model=self._model, + prefix=self._prefix, + suffix=self._suffix, + batch_size=self._batch_size, + progress_bar=self._progress_bar, + meta_fields_to_embed=self._meta_fields_to_embed, + embedding_separator=self._embedding_separator, + api_key=self._api_key.to_dict(), + config=self._config, ) @classmethod @@ -127,19 +124,20 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: """ Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. """ - texts_to_embed = {} + texts_to_embed: List[str] = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) for key in self._meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] - texts_to_embed[doc.id] = ( - self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + text_to_embed = ( + self._prefix + self._embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self._suffix ) + texts_to_embed.append(text_to_embed) return texts_to_embed - def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: + def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: """ Embed a list of texts in batches. """ @@ -147,25 +145,19 @@ def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple all_embeddings = [] meta: Dict[str, Any] = {} for batch in tqdm( - batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" + batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings" ): - args: Dict[str, Any] = {"model": self.model, "contents": [b[1] for b in batch]} - if self.config: - args["config"] = types.EmbedContentConfig(**self.config) if self.config else None - - try: - response = self.client.models.embed_content(**args) - except Exception as exc: - ids = ", ".join(b[0] for b in batch) - msg = "Failed embedding of documents {ids} caused by {exc}" - logger.exception(msg, ids=ids, exc=exc) - continue + args: Dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]} + if self._config: + args["config"] = types.EmbedContentConfig(**self._config) if self._config else None + + response = self._client.models.embed_content(**args) embeddings = [el.values for el in response.embeddings] all_embeddings.extend(embeddings) if "model" not in meta: - meta["model"] = self.model + meta["model"] = self._model return all_embeddings, meta @@ -191,7 +183,7 @@ def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) + embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size) for doc, emb in zip(documents, embeddings): doc.embedding = emb diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py index 7f8608c1e..415d5fc21 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -36,7 +36,7 @@ class GoogleGenAITextEmbedder: ``` """ - def __init__( # pylint: disable=too-many-positional-arguments + def __init__( self, *, api_key: Secret = Secret.from_env_var("GOOGLE_API_KEY"), @@ -61,6 +61,7 @@ def __init__( # pylint: disable=too-many-positional-arguments A string to add at the end of each text to embed. :param config: A dictionary of keyword arguments to configure embedding content configuration `types.EmbedContentConfig`. + If not specified, it defaults to {"task_type": "SEMANTIC_SIMILARITY"}. For more information, see the [Google AI Task types](https://ai.google.dev/gemini-api/docs/embeddings#task-types). """ diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index eedf2f13a..0d59ba462 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -27,15 +27,15 @@ class TestGoogleGenAIDocumentEmbedder: def test_init_default(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "fake-api-key") embedder = GoogleGenAIDocumentEmbedder() - assert embedder.api_key.resolve_value() == "fake-api-key" - assert embedder.model == "text-embedding-004" - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.batch_size == 32 - assert embedder.progress_bar is True - assert embedder.meta_fields_to_embed == [] - assert embedder.embedding_separator == "\n" - assert embedder.config == {"task_type": "SEMANTIC_SIMILARITY"} + assert embedder._api_key.resolve_value() == "fake-api-key" + assert embedder._model == "text-embedding-004" + assert embedder._prefix == "" + assert embedder._suffix == "" + assert embedder._batch_size == 32 + assert embedder._progress_bar is True + assert embedder._meta_fields_to_embed == [] + assert embedder._embedding_separator == "\n" + assert embedder._config == {"task_type": "SEMANTIC_SIMILARITY"} def test_init_with_parameters(self, monkeypatch): embedder = GoogleGenAIDocumentEmbedder( @@ -49,15 +49,15 @@ def test_init_with_parameters(self, monkeypatch): embedding_separator=" | ", config={"task_type": "CLASSIFICATION"}, ) - assert embedder.api_key.resolve_value() == "fake-api-key-2" - assert embedder.model == "model" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 - assert embedder.progress_bar is False - assert embedder.meta_fields_to_embed == ["test_field"] - assert embedder.embedding_separator == " | " - assert embedder.config == {"task_type": "CLASSIFICATION"} + assert embedder._api_key.resolve_value() == "fake-api-key-2" + assert embedder._model == "model" + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._batch_size == 64 + assert embedder._progress_bar is False + assert embedder._meta_fields_to_embed == ["test_field"] + assert embedder._embedding_separator == " | " + assert embedder._config == {"task_type": "CLASSIFICATION"} def test_init_with_parameters_and_env_vars(self, monkeypatch): embedder = GoogleGenAIDocumentEmbedder( @@ -71,15 +71,15 @@ def test_init_with_parameters_and_env_vars(self, monkeypatch): embedding_separator=" | ", config={"task_type": "CLASSIFICATION"}, ) - assert embedder.api_key.resolve_value() == "fake-api-key-2" - assert embedder.model == "model" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 - assert embedder.progress_bar is False - assert embedder.meta_fields_to_embed == ["test_field"] - assert embedder.embedding_separator == " | " - assert embedder.config == {"task_type": "CLASSIFICATION"} + assert embedder._api_key.resolve_value() == "fake-api-key-2" + assert embedder._model == "model" + assert embedder._prefix == "prefix" + assert embedder._suffix == "suffix" + assert embedder._batch_size == 64 + assert embedder._progress_bar is False + assert embedder._meta_fields_to_embed == ["test_field"] + assert embedder._embedding_separator == " | " + assert embedder._config == {"task_type": "CLASSIFICATION"} def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("GOOGLE_API_KEY", raising=False) From f20bdffcd5bd8a10477ca607c0706bad42e35cff Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 6 Jun 2025 18:34:39 +0530 Subject: [PATCH 10/52] refactor: update _prepare_texts_to_embed to return a list instead of a dictionary --- .../embedders/google_genai/document_embedder.py | 5 ++++- .../google_genai/tests/test_document_embedder.py | 14 +++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index ff457891f..fcd0d9f44 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -127,7 +127,10 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: texts_to_embed: List[str] = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) for key in self._meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None + str(doc.meta[key]) + for key in + self._meta_fields_to_embed + if key in doc.meta and doc.meta[key] is not None ] text_to_embed = ( diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index 0d59ba462..21936d2f6 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -151,13 +151,13 @@ def test_prepare_texts_to_embed_w_metadata(self): ) prepared_texts = embedder._prepare_texts_to_embed(documents) - assert prepared_texts == { - "0": "meta_value 0 | document number 0:\ncontent", - "1": "meta_value 1 | document number 1:\ncontent", - "2": "meta_value 2 | document number 2:\ncontent", - "3": "meta_value 3 | document number 3:\ncontent", - "4": "meta_value 4 | document number 4:\ncontent", - } + assert prepared_texts == [ + 'meta_value 0 | document number 0:\ncontent', + 'meta_value 1 | document number 1:\ncontent', + 'meta_value 2 | document number 2:\ncontent', + 'meta_value 3 | document number 3:\ncontent', + 'meta_value 4 | document number 4:\ncontent' + ] def test_run_wrong_input_format(self): embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) From 38525c05b63d688cd545d67b1c9f539cebd1dfc5 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 6 Jun 2025 18:38:14 +0530 Subject: [PATCH 11/52] refactor: format code for better readability and consistency in document embedder --- .../google_genai/document_embedder.py | 20 +++++++++------ .../tests/test_document_embedder.py | 25 +++++++++++-------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index fcd0d9f44..f38082b33 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -85,7 +85,8 @@ def __init__( self._meta_fields_to_embed = meta_fields_to_embed or [] self._embedding_separator = embedding_separator self._client = genai.Client(api_key=api_key.resolve_value()) - self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} + self._config = config if config is not None else { + "task_type": "SEMANTIC_SIMILARITY"} def to_dict(self) -> Dict[str, Any]: """ @@ -127,14 +128,14 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: texts_to_embed: List[str] = [] for doc in documents: meta_values_to_embed = [ - str(doc.meta[key]) - for key in - self._meta_fields_to_embed + str(doc.meta[key]) + for key in self._meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] text_to_embed = ( - self._prefix + self._embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self._suffix + self._prefix + self._embedding_separator.join( + [*meta_values_to_embed, doc.content or ""]) + self._suffix ) texts_to_embed.append(text_to_embed) @@ -150,9 +151,11 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List for batch in tqdm( batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings" ): - args: Dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]} + args: Dict[str, Any] = {"model": self._model, + "contents": [b[1] for b in batch]} if self._config: - args["config"] = types.EmbedContentConfig(**self._config) if self._config else None + args["config"] = types.EmbedContentConfig( + **self._config) if self._config else None response = self._client.models.embed_content(**args) @@ -186,7 +189,8 @@ def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size) + embeddings, meta = self._embed_batch( + texts_to_embed=texts_to_embed, batch_size=self._batch_size) for doc, emb in zip(documents, embeddings): doc.embedding = emb diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index 21936d2f6..46da7092c 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -142,7 +142,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): def test_prepare_texts_to_embed_w_metadata(self): documents = [ - Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) + Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={ + "meta_field": f"meta_value {i}"}) for i in range(5) ] @@ -152,15 +153,16 @@ def test_prepare_texts_to_embed_w_metadata(self): prepared_texts = embedder._prepare_texts_to_embed(documents) assert prepared_texts == [ - 'meta_value 0 | document number 0:\ncontent', - 'meta_value 1 | document number 1:\ncontent', - 'meta_value 2 | document number 2:\ncontent', - 'meta_value 3 | document number 3:\ncontent', - 'meta_value 4 | document number 4:\ncontent' + "meta_value 0 | document number 0:\ncontent", + "meta_value 1 | document number 1:\ncontent", + "meta_value 2 | document number 2:\ncontent", + "meta_value 3 | document number 3:\ncontent", + "meta_value 4 | document number 4:\ncontent" ] def test_run_wrong_input_format(self): - embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) + embedder = GoogleGenAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key")) # wrong formats string_input = "text" @@ -173,7 +175,8 @@ def test_run_wrong_input_format(self): embedder.run(documents=list_integers_input) def test_run_on_empty_list(self): - embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) + embedder = GoogleGenAIDocumentEmbedder( + api_key=Secret.from_token("fake-api-key")) empty_list_input = [] result = embedder.run(documents=empty_list_input) @@ -189,12 +192,14 @@ def test_run_on_empty_list(self): def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + Document(content="A transformer is a deep learning architecture", meta={ + "topic": "ML"}), ] model = "text-embedding-004" - embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ") + embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=[ + "topic"], embedding_separator=" | ") result = embedder.run(documents=docs) documents_with_embeddings = result["documents"] From f8e5f8a893b0441fd6a29859727ed22e694a6224 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 6 Jun 2025 18:42:31 +0530 Subject: [PATCH 12/52] refactor: improve code formatting for consistency and readability in document embedder and tests --- .../google_genai/document_embedder.py | 15 ++++------ .../tests/test_document_embedder.py | 29 +++++++------------ .../google_genai/tests/test_text_embedder.py | 6 ++-- 3 files changed, 19 insertions(+), 31 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index f38082b33..3da4aec8a 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -85,8 +85,7 @@ def __init__( self._meta_fields_to_embed = meta_fields_to_embed or [] self._embedding_separator = embedding_separator self._client = genai.Client(api_key=api_key.resolve_value()) - self._config = config if config is not None else { - "task_type": "SEMANTIC_SIMILARITY"} + self._config = config if config is not None else {"task_type": "SEMANTIC_SIMILARITY"} def to_dict(self) -> Dict[str, Any]: """ @@ -134,8 +133,7 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: ] text_to_embed = ( - self._prefix + self._embedding_separator.join( - [*meta_values_to_embed, doc.content or ""]) + self._suffix + self._prefix + self._embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self._suffix ) texts_to_embed.append(text_to_embed) @@ -151,11 +149,9 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List for batch in tqdm( batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings" ): - args: Dict[str, Any] = {"model": self._model, - "contents": [b[1] for b in batch]} + args: Dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]} if self._config: - args["config"] = types.EmbedContentConfig( - **self._config) if self._config else None + args["config"] = types.EmbedContentConfig(**self._config) if self._config else None response = self._client.models.embed_content(**args) @@ -189,8 +185,7 @@ def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, meta = self._embed_batch( - texts_to_embed=texts_to_embed, batch_size=self._batch_size) + embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self._batch_size) for doc, emb in zip(documents, embeddings): doc.embedding = emb diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index 46da7092c..31e55baf4 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -92,8 +92,7 @@ def test_to_dict(self, monkeypatch): data = component.to_dict() assert data == { "type": ( - "haystack_integrations.components.embedders" - ".google_genai.document_embedder.GoogleGenAIDocumentEmbedder" + "haystack_integrations.components.embedders.google_genai.document_embedder.GoogleGenAIDocumentEmbedder" ), "init_parameters": { "model": "text-embedding-004", @@ -124,8 +123,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): data = component.to_dict() assert data == { "type": ( - "haystack_integrations.components.embedders" - ".google_genai.document_embedder.GoogleGenAIDocumentEmbedder" + "haystack_integrations.components.embedders.google_genai.document_embedder.GoogleGenAIDocumentEmbedder" ), "init_parameters": { "model": "model", @@ -142,8 +140,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): def test_prepare_texts_to_embed_w_metadata(self): documents = [ - Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={ - "meta_field": f"meta_value {i}"}) + Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) ] @@ -157,12 +154,11 @@ def test_prepare_texts_to_embed_w_metadata(self): "meta_value 1 | document number 1:\ncontent", "meta_value 2 | document number 2:\ncontent", "meta_value 3 | document number 3:\ncontent", - "meta_value 4 | document number 4:\ncontent" + "meta_value 4 | document number 4:\ncontent", ] def test_run_wrong_input_format(self): - embedder = GoogleGenAIDocumentEmbedder( - api_key=Secret.from_token("fake-api-key")) + embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) # wrong formats string_input = "text" @@ -175,8 +171,7 @@ def test_run_wrong_input_format(self): embedder.run(documents=list_integers_input) def test_run_on_empty_list(self): - embedder = GoogleGenAIDocumentEmbedder( - api_key=Secret.from_token("fake-api-key")) + embedder = GoogleGenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) empty_list_input = [] result = embedder.run(documents=empty_list_input) @@ -192,14 +187,12 @@ def test_run_on_empty_list(self): def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document(content="A transformer is a deep learning architecture", meta={ - "topic": "ML"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), ] model = "text-embedding-004" - embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=[ - "topic"], embedding_separator=" | ") + embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ") result = embedder.run(documents=docs) documents_with_embeddings = result["documents"] @@ -211,6 +204,6 @@ def test_run(self): assert len(doc.embedding) == 768 assert all(isinstance(x, float) for x in doc.embedding) - assert ( - "text" in result["meta"]["model"] and "004" in result["meta"]["model"] - ), "The model name does not contain 'text' and '004'" + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + ) diff --git a/integrations/google_genai/tests/test_text_embedder.py b/integrations/google_genai/tests/test_text_embedder.py index e5af84f66..bb700527b 100644 --- a/integrations/google_genai/tests/test_text_embedder.py +++ b/integrations/google_genai/tests/test_text_embedder.py @@ -160,6 +160,6 @@ def test_run(self): assert len(result["embedding"]) == 768 assert all(isinstance(x, float) for x in result["embedding"]) - assert ( - "text" in result["meta"]["model"] and "004" in result["meta"]["model"] - ), "The model name does not contain 'text' and '004'" + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + ) From 666d0d51f6b2190d8f501538b5f01f4ff1d9f2b2 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 6 Jun 2025 18:46:29 +0530 Subject: [PATCH 13/52] refactor: update _prepare_texts_to_embed to return a list instead of a dictionary --- .../components/embedders/google_genai/document_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index 3da4aec8a..4f143a07e 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -120,7 +120,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleGenAIDocumentEmbedder": deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) - def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]: + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: """ Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. """ From 0b8d687f3c952557a795f28a5211a50b6c8c1d2d Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 6 Jun 2025 19:44:21 +0530 Subject: [PATCH 14/52] feat: add new author to project metadata in pyproject.toml --- integrations/google_genai/pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/integrations/google_genai/pyproject.toml b/integrations/google_genai/pyproject.toml index e3c94cc60..65c8797ea 100644 --- a/integrations/google_genai/pyproject.toml +++ b/integrations/google_genai/pyproject.toml @@ -10,7 +10,10 @@ readme = "README.md" requires-python = ">=3.9" license = "Apache-2.0" keywords = [] -authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, + { name = "Gary Badwal", email = "gurpreet071999@gmail.com" } +] classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", From 706de87d9a245370d45abf4f98567ad89659205b Mon Sep 17 00:00:00 2001 From: garybadwal Date: Sat, 21 Jun 2025 10:10:34 +0530 Subject: [PATCH 15/52] feat: add asynchronous embedding methods for GoogleGenAIDocumentEmbedder and GoogleGenAITextEmbedder --- .../google_genai/document_embedder.py | 53 +++++++++++++++++++ .../embedders/google_genai/text_embedder.py | 20 +++++++ .../tests/test_document_embedder.py | 30 +++++++++++ .../google_genai/tests/test_text_embedder.py | 19 +++++++ 4 files changed, 122 insertions(+) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index 4f143a07e..61175ed0c 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -162,6 +162,30 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List meta["model"] = self._model return all_embeddings, meta + + async def _embed_batch_async(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: + """ + Embed a list of texts in batches asynchronously. + """ + + all_embeddings = [] + meta: Dict[str, Any] = {} + for batch in tqdm( + batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings" + ): + args: Dict[str, Any] = {"model": self._model, "contents": [b[1] for b in batch]} + if self._config: + args["config"] = types.EmbedContentConfig(**self._config) if self._config else None + + response = await self._client.aio.models.embed_content(**args) + + embeddings = [el.values for el in response.embeddings] + all_embeddings.extend(embeddings) + + if "model" not in meta: + meta["model"] = self._model + + return all_embeddings, meta @component.output_types(documents=List[Document], meta=Dict[str, Any]) def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]: @@ -191,3 +215,32 @@ def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict doc.embedding = emb return {"documents": documents, "meta": meta} + + @component.output_types(documents=List[Document], meta=Dict[str, Any]) + async def run_async(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]: + """ + Embeds a list of documents asynchronously. + + :param documents: + A list of documents to embed. + + :returns: + A dictionary with the following keys: + - `documents`: A list of documents with embeddings. + - `meta`: Information about the usage of the model. + """ + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): + error_message_documents = ( + "GoogleGenAIDocumentEmbedder expects a list of Documents as input. " + "In case you want to embed a string, please use the GoogleGenAITextEmbedder." + ) + raise TypeError(error_message_documents) + + texts_to_embed = self._prepare_texts_to_embed(documents=documents) + + embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self._batch_size) + + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + + return {"documents": documents, "meta": meta} \ No newline at end of file diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py index 415d5fc21..bde0a20ce 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -137,3 +137,23 @@ def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]: create_kwargs = self._prepare_input(text=text) response = self._client.models.embed_content(**create_kwargs) return self._prepare_output(result=response) + + @component.output_types(embedding=List[float], meta=Dict[str, Any]) + async def run_async(self, text: str): + """ + Asynchronously embed a single string. + + This is the asynchronous version of the `run` method. It has the same parameters and return values + but can be used with `await` in async code. + + :param text: + Text to embed. + + :returns: + A dictionary with the following keys: + - `embedding`: The embedding of the input text. + - `meta`: Information about the usage of the model. + """ + create_kwargs = self._prepare_input(text=text) + response = await self._client.aio.models.embed_content(**create_kwargs) + return self._prepare_output(result=response) \ No newline at end of file diff --git a/integrations/google_genai/tests/test_document_embedder.py b/integrations/google_genai/tests/test_document_embedder.py index 31e55baf4..9511eee5b 100644 --- a/integrations/google_genai/tests/test_document_embedder.py +++ b/integrations/google_genai/tests/test_document_embedder.py @@ -207,3 +207,33 @@ def test_run(self): assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( "The model name does not contain 'text' and '004'" ) + + @pytest.mark.asyncio + @pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", + ) + @pytest.mark.integration + async def test_run_async(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + model = "text-embedding-004" + + embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ") + + result = await embedder.run_async(documents=docs) + documents_with_embeddings = result["documents"] + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 768 + assert all(isinstance(x, float) for x in doc.embedding) + + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + ) \ No newline at end of file diff --git a/integrations/google_genai/tests/test_text_embedder.py b/integrations/google_genai/tests/test_text_embedder.py index bb700527b..cb0f4e8b1 100644 --- a/integrations/google_genai/tests/test_text_embedder.py +++ b/integrations/google_genai/tests/test_text_embedder.py @@ -163,3 +163,22 @@ def test_run(self): assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( "The model name does not contain 'text' and '004'" ) + + @pytest.mark.asyncio + @pytest.mark.skipif( + not os.environ.get("GOOGLE_API_KEY", None), + reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.", + ) + @pytest.mark.integration + async def test_run_async(self): + model = "text-embedding-004" + + embedder = GoogleGenAITextEmbedder(model=model) + result = await embedder.run_async(text="The food was delicious") + + assert len(result["embedding"]) == 768 + assert all(isinstance(x, float) for x in result["embedding"]) + + assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], ( + "The model name does not contain 'text' and '004'" + ) \ No newline at end of file From 818e40dfddcbaa62a64e416d6bec3992b0831a5d Mon Sep 17 00:00:00 2001 From: garybadwal Date: Sat, 21 Jun 2025 10:30:43 +0530 Subject: [PATCH 16/52] fix: ensure consistent formatting for pylint --- .../embedders/google_genai/document_embedder.py | 8 +++++--- .../components/embedders/google_genai/text_embedder.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index 3987f87ed..e5f4b0b15 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -170,8 +170,10 @@ def _embed_batch( meta["model"] = self._model return all_embeddings, meta - - async def _embed_batch_async(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]: + + async def _embed_batch_async( + self, texts_to_embed: List[str], batch_size: int + ) -> Tuple[List[List[float]], Dict[str, Any]]: """ Embed a list of texts in batches asynchronously. """ @@ -252,4 +254,4 @@ async def run_async(self, documents: List[Document]) -> Dict[str, Union[List[Doc for doc, emb in zip(documents, embeddings): doc.embedding = emb - return {"documents": documents, "meta": meta} \ No newline at end of file + return {"documents": documents, "meta": meta} diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py index 7da6325ee..9ce93ab81 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -157,4 +157,4 @@ async def run_async(self, text: str): """ create_kwargs = self._prepare_input(text=text) response = await self._client.aio.models.embed_content(**create_kwargs) - return self._prepare_output(result=response) \ No newline at end of file + return self._prepare_output(result=response) From 4909c014a0fa5048d3095807b7dd07725bca5990 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Sat, 21 Jun 2025 10:34:18 +0530 Subject: [PATCH 17/52] fix: update return type annotation for run_async method in GoogleGenAIDocumentEmbedder --- .../components/embedders/google_genai/document_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index e5f4b0b15..e67ab23e3 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -228,7 +228,7 @@ def run(self, documents: List[Document]) -> Union[Dict[str, List[Document]], Dic return {"documents": documents, "meta": meta} @component.output_types(documents=List[Document], meta=Dict[str, Any]) - async def run_async(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]: + async def run_async(self, documents: List[Document]) -> Union[Dict[str, List[Document]], Dict[str, Any]]: """ Embeds a list of documents asynchronously. From 8f6f9f1a9c033cd16b80bd992972e812284f5a86 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Sat, 21 Jun 2025 10:37:06 +0530 Subject: [PATCH 18/52] fix: update return type annotation for run_async method in GoogleGenAITextEmbedder --- .../components/embedders/google_genai/text_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py index 9ce93ab81..d61e49fff 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py @@ -140,7 +140,7 @@ def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]: return self._prepare_output(result=response) @component.output_types(embedding=List[float], meta=Dict[str, Any]) - async def run_async(self, text: str): + async def run_async(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]: """ Asynchronously embed a single string. From 60a4506183531d50cd8571156db027fd0dd1405e Mon Sep 17 00:00:00 2001 From: garybadwal Date: Tue, 24 Jun 2025 12:01:44 +0530 Subject: [PATCH 19/52] fix: update return type annotation and handle None values in _embed_batch_async method --- .../embedders/google_genai/document_embedder.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index e67ab23e3..1808200c2 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -173,7 +173,7 @@ def _embed_batch( async def _embed_batch_async( self, texts_to_embed: List[str], batch_size: int - ) -> Tuple[List[List[float]], Dict[str, Any]]: + ) -> Tuple[List[Optional[List[float]]], Dict[str, Any]]: """ Embed a list of texts in batches asynchronously. """ @@ -188,9 +188,14 @@ async def _embed_batch_async( args["config"] = types.EmbedContentConfig(**self._config) if self._config else None response = await self._client.aio.models.embed_content(**args) - - embeddings = [el.values for el in response.embeddings] - all_embeddings.extend(embeddings) + + embeddings = [] + if response.embeddings: + for el in response.embeddings: + embeddings.append(el.values if el.values else None) + all_embeddings.extend(embeddings) + else: + all_embeddings.extend([None] * len(batch)) if "model" not in meta: meta["model"] = self._model From 8734e264a9e8f13fb5f4716c23b451f2a866605c Mon Sep 17 00:00:00 2001 From: garybadwal Date: Tue, 24 Jun 2025 12:04:56 +0530 Subject: [PATCH 20/52] fix: remove unnecessary blank line in _embed_batch_async method --- .../components/embedders/google_genai/document_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py index 1808200c2..595851493 100644 --- a/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py +++ b/integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py @@ -188,7 +188,7 @@ async def _embed_batch_async( args["config"] = types.EmbedContentConfig(**self._config) if self._config else None response = await self._client.aio.models.embed_content(**args) - + embeddings = [] if response.embeddings: for el in response.embeddings: From 15078ea8c21bb994c6975458322783444af547d7 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Sat, 27 Sep 2025 09:54:31 +0530 Subject: [PATCH 21/52] feat: add CometAPI integration with CometAPIChatGenerator class and configuration --- integrations/cometapi/pydoc/config.yml | 29 ++++++++++ .../generators/cometapi/__init__.py | 7 +++ .../cometapi/chat/chat_generator.py | 56 +++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 integrations/cometapi/pydoc/config.yml create mode 100644 integrations/cometapi/src/haystack_integrations/components/generators/cometapi/__init__.py create mode 100644 integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py diff --git a/integrations/cometapi/pydoc/config.yml b/integrations/cometapi/pydoc/config.yml new file mode 100644 index 000000000..4dbb825bd --- /dev/null +++ b/integrations/cometapi/pydoc/config.yml @@ -0,0 +1,29 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.generators.cometapi.chat.chat_generator", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Comet API integration for Haystack + category_slug: integrations-api + title: Comet API + slug: integrations-cometapi + order: 91 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_cometapi.md \ No newline at end of file diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/__init__.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/__init__.py new file mode 100644 index 000000000..98f0ce06a --- /dev/null +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .chat.chat_generator import CometAPIChatGenerator + +__all__ = ["CometAPIChatGenerator"] diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py new file mode 100644 index 000000000..1cd92dd3e --- /dev/null +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -0,0 +1,56 @@ +from typing import Optional, List, Dict, Any, Callable + +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.utils import Secret + +class CometAPIChatGenerator(OpenAIChatGenerator): + + """ + A chat generator that uses the CometAPI for generating chat responses. + + This class extends Haystack's OpenAIChatGenerator to specifically interact with the CometAPI. + It sets the `api_base_url` to the CometAPI endpoint and allows for all the + standard configurations available in the OpenAIChatGenerator. + + :param api_key: The API key for authenticating with the CometAPI. Defaults to + loading from the "COMET_API_KEY" environment variable. + :param model: The name of the model to use for chat generation (e.g., "gpt-4o-mini", "grok-3-mini"). + Defaults to "gpt-4o-mini". + :param streaming_callback: An optional callable that will be called with each chunk of + a streaming response. + :param generation_kwargs: Optional keyword arguments to pass to the underlying generation + API call. + :param timeout: The maximum time in seconds to wait for a response from the API. + :param max_retries: The maximum number of times to retry a failed API request. + :param tools: An optional list of tool definitions that the model can use. + :param tools_strict: If True, the model is forced to use one of the provided tools if a tool call is made. + :param http_client_kwargs: Optional keyword arguments to pass to the HTTP client. + """ + + def __init__( + self, + api_key = Secret.from_env_var("COMET_API_KEY"), + model = "gpt-4o-mini", + streaming_callback: Optional[Callable[[str], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[int] = None, + max_retries: Optional[int] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tools_strict: bool = False, + http_client_kwargs: Optional[Dict[str, Any]] = None + ): + + api_base_url = "https://api.cometapi.com/v1" + + super().__init__( + api_key=api_key, + model=model, + api_base_url=api_base_url, + streaming_callback=streaming_callback, + generation_kwargs=generation_kwargs, + timeout=timeout, + max_retries=max_retries, + tools=tools, + tools_strict=tools_strict, + http_client_kwargs=http_client_kwargs + ) From 6c8c6413fb1cffdb945ecdfa87384962b3aba563 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Thu, 9 Oct 2025 19:45:26 +0530 Subject: [PATCH 22/52] feat: add CometAPI integration with chat generator and tests --- integrations/cometapi/LICENSE.txt | 192 +++++++++++++ integrations/cometapi/README.md | 42 +++ integrations/cometapi/pyproject.toml | 189 +++++++++++++ .../tests/test_cometapi_chat_generator.py | 265 ++++++++++++++++++ 4 files changed, 688 insertions(+) create mode 100644 integrations/cometapi/LICENSE.txt create mode 100644 integrations/cometapi/README.md create mode 100644 integrations/cometapi/pyproject.toml create mode 100644 integrations/cometapi/tests/test_cometapi_chat_generator.py diff --git a/integrations/cometapi/LICENSE.txt b/integrations/cometapi/LICENSE.txt new file mode 100644 index 000000000..c81963001 --- /dev/null +++ b/integrations/cometapi/LICENSE.txt @@ -0,0 +1,192 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (which shall not include communications that are marked or + designated in writing by the copyright owner as "Not a Contribution"). + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control + systems, and issue tracking systems that are managed by, or on behalf + of, the Licensor for the purpose of discussing and improving the Work, + but excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution". + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to use, reproduce, modify, distribute, and prepare + Derivative Works of, publicly display, publicly perform, sublicense, + and distribute the Work and such Derivative Works in Source or Object + form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, trademark, patent, + attribution and other notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright notice to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Support. You may choose to offer, and to + charge a fee for, warranty, support, indemnity or other liability + obligations and/or rights consistent with this License. However, in + accepting such obligations, You may act only on Your own behalf and on + Your sole responsibility, not on behalf of any other Contributor, and + only if You agree to indemnify, defend, and hold each Contributor + harmless for any liability incurred by, or claims asserted against, + such Contributor by reason of your accepting any such warranty or support. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same page as the copyright notice for easier identification within + third-party archives. + + Copyright 2023-present deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/integrations/cometapi/README.md b/integrations/cometapi/README.md new file mode 100644 index 000000000..87edeb7a9 --- /dev/null +++ b/integrations/cometapi/README.md @@ -0,0 +1,42 @@ +# Comet API Haystack Integration + +[![PyPI - Version](https://img.shields.io/pypi/v/cometapi-haystack.svg)](https://pypi.org/project/cometapi-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/cometapi-haystack.svg)](https://pypi.org/project/cometapi-haystack) + +----- + +**Table of Contents** + +- [Installation](#installation) +- [Usage](#usage) +- [License](#license) + +## Installation + +```console +pip install cometapi-haystack +``` + +## Usage + +This integration provides components to use models via the new Comet APIs. + +### Chat Generator + +```python +from haystack.dataclasses.chat_message import ChatMessage +from haystack_integrations.components.generators.cometapi import CometAPIChatGenerator + +# Initialize the chat generator +chat_generator = CometAPIChatGenerator(model="grok-3-mini") + +# Generate a response +messages = [ChatMessage.from_user("Tell me about the future of AI")] +response = chat_generator.run(messages=messages) +print(response["replies"][0].text) +``` + + +## License + +`cometapi-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. \ No newline at end of file diff --git a/integrations/cometapi/pyproject.toml b/integrations/cometapi/pyproject.toml new file mode 100644 index 000000000..900975ec3 --- /dev/null +++ b/integrations/cometapi/pyproject.toml @@ -0,0 +1,189 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "cometapi-haystack" +dynamic = ["version"] +description = 'Use Comet API with Haystack to build AI applications with 500+ AI models.' +readme = "README.md" +requires-python = ">=3.9" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, + { name = "Gary Badwal", email = "gurpreet071999@gmail.com" } +] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai>=2.13.2",] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cometapi#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cometapi" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/cometapi-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/cometapi-v[0-9]*"' + +[tool.hatch.envs.default] +installer = "uv" +dependencies = ["haystack-pydoc-tools", "ruff"] + +[tool.hatch.envs.default.scripts] +docs = ["pydoc-markdown pydoc/config.yml"] +fmt = "ruff check --fix {args} && ruff format {args}" +fmt-check = "ruff check {args} && ruff format --check {args}" + +[tool.hatch.envs.test] +dependencies = [ + "pytest", + "pytest-asyncio", + "pytest-cov", + "pytest-rerunfailures", + "mypy", + "pip" +] + +[tool.hatch.envs.test.scripts] +unit = 'pytest -m "not integration" {args:tests}' +integration = 'pytest -m "integration" {args:tests}' +all = 'pytest {args:tests}' +cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x' + +types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" + +# TODO: remove lint environment once this integration is properly typed +# test environment should be used instead +# https://github.com/deepset-ai/haystack-core-integrations/issues/1771 +[tool.hatch.envs.lint] +installer = "uv" +detached = true +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "more-itertools"] + +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 + +[tool.ruff.lint] +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Ignore unused params + "ARG001", + "ARG002", + "ARG005", + # Allow function call argument defaults e.g. `Secret.from_env_var` + "B008", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.lint.isort] +known-first-party = ["haystack_integrations"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] +# Examples can use print statements +"examples/**/*" = ["T201"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = true + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "haystack_integrations.*", + "pytest.*", + "numpy.*", + "jsonref.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = [ + "unit: unit tests", + "integration: integration tests", + "generators: generators tests", +] +log_cli = true +asyncio_mode = "auto" diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py new file mode 100644 index 000000000..cbb19969d --- /dev/null +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -0,0 +1,265 @@ +import os +from datetime import datetime +from unittest.mock import AsyncMock, patch + +import pytest +import pytz +from haystack.dataclasses import ( + ChatMessage, + ChatRole, + StreamingChunk, +) +from haystack.tools import Tool +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice + +from haystack_integrations.components.generators.cometapi.chat.chat_generator import ( + CometAPIChatGenerator, +) + +pytestmark = pytest.mark.asyncio + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France"), + ] + + +def weather(city: str): + """Get weather for a given city.""" + return f"The weather in {city} is sunny and 32°C" + + +@pytest.fixture +def tools(): + tool_parameters = { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=weather, + ) + + return [tool] + + +@pytest.fixture +def mock_async_chat_completion(): + """ + Mock the Async OpenAI API completion response and reuse it for async tests + """ + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_chat_completion_create: + completion = ChatCompletion( + id="foo", + model="openai/gpt-4o-mini", + object="chat.completion", + choices=[ + Choice( + finish_reason="stop", + logprobs=None, + index=0, + message=ChatCompletionMessage(content="Hello world!", role="assistant"), + ) + ], + created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()), + usage={ + "prompt_tokens": 57, + "completion_tokens": 40, + "total_tokens": 97, + }, + ) + # For async mocks, the return value should be awaitable + mock_chat_completion_create.return_value = completion + yield mock_chat_completion_create + + +class TestCometAPIChatGeneratorAsync: + def test_init_default_async(self, monkeypatch): + monkeypatch.setenv("COMET_API_KEY", "test-api-key") + component = CometAPIChatGenerator() + + assert isinstance(component.async_client, AsyncOpenAI) + assert component.async_client.api_key == "test-api-key" + assert component.async_client.base_url == "https://api.cometapi.com/v1/" + assert not component.generation_kwargs + + @pytest.mark.asyncio + async def test_run_async(self, chat_messages, mock_async_chat_completion, monkeypatch): # noqa: ARG002 + monkeypatch.setenv("COMET_API_KEY", "fake-api-key") + component = CometAPIChatGenerator() + response = await component.run_async(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.asyncio + async def test_run_async_with_params(self, chat_messages, mock_async_chat_completion, monkeypatch): + monkeypatch.setenv("COMET_API_KEY", "fake-api-key") + component = CometAPIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) + response = await component.run_async(chat_messages) + + # check that the component calls the OpenAI API with the correct parameters + _, kwargs = mock_async_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_async(self): + chat_messages = [ChatMessage.from_user("What's the capital of France")] + component = CometAPIChatGenerator() + results = await component.run_async(chat_messages) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + assert "openai/gpt-4o-mini" in message.meta["model"] + assert message.meta["finish_reason"] == "stop" + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_streaming_async(self): + counter = 0 + responses = "" + + async def callback(chunk: StreamingChunk): + nonlocal counter + nonlocal responses + counter += 1 + responses += chunk.content if chunk.content else "" + + component = CometAPIChatGenerator(streaming_callback=callback) + results = await component.run_async([ChatMessage.from_user("What's the capital of France?")]) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + + assert "openai/gpt-4o-mini" in message.meta["model"] + assert message.meta["finish_reason"] == "stop" + + assert counter > 1 + assert "Paris" in responses + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_with_tools_and_response_async(self, tools): + """ + Integration test that the CometAPIChatGenerator component can run with tools and get a response. + """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = CometAPIChatGenerator(tools=tools) + results = await component.run_async(messages=initial_messages, generation_kwargs={"tool_choice": "auto"}) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_calls" + + new_messages = [ + initial_messages[0], + tool_message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + # Pass the tool result to the model to get the final response + results = await component.run_async(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_with_tools_streaming_async(self, tools): + """ + Integration test that the CometAPIChatGenerator component can run with tools and streaming. + """ + + counter = 0 + tool_calls = [] + + async def callback(chunk: StreamingChunk): + nonlocal counter + nonlocal tool_calls + counter += 1 + if chunk.meta.get("tool_calls"): + tool_calls.extend(chunk.meta["tool_calls"]) + + component = CometAPIChatGenerator(tools=tools, streaming_callback=callback) + results = await component.run_async( + [ChatMessage.from_user("What's the weather like in Paris?")], + generation_kwargs={"tool_choice": "auto"}, + ) + + assert len(results["replies"]) > 0, "No replies received" + assert counter > 1, "Streaming callback was not called multiple times" + assert tool_calls, "No tool calls received in streaming" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_calls" From 6f0d6954847787e2b25dcb673a1e71709aa003cb Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 15:44:55 +0530 Subject: [PATCH 23/52] feat: add CometAPI integration with workflow and labeler configuration --- .github/labeler.yml | 5 ++ .github/workflows/cometapi.yml | 83 ++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 .github/workflows/cometapi.yml diff --git a/.github/labeler.yml b/.github/labeler.yml index 83c5680e1..7c612ed20 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -39,6 +39,11 @@ integration:cohere: - any-glob-to-any-file: "integrations/cohere/**/*" - any-glob-to-any-file: ".github/workflows/cohere.yml" +integration:cometapi: + - changed-files: + - any-glob-to-any-file: "integrations/cometapi/**/*" + - any-glob-to-any-file: ".github/workflows/cometapi.yml" + integration:deepeval: - changed-files: - any-glob-to-any-file: "integrations/deepeval/**/*" diff --git a/.github/workflows/cometapi.yml b/.github/workflows/cometapi.yml new file mode 100644 index 000000000..beb9a5eb9 --- /dev/null +++ b/.github/workflows/cometapi.yml @@ -0,0 +1,83 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / cometapi + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/cometapi/**" + - "!integrations/cometapi/*.md" + - ".github/workflows/cometapi.yml" + +defaults: + run: + working-directory: integrations/cometapi + +concurrency: + group: cometapi-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + COMET_API_KEY: "${{ secrets.COMET_API_KEY }}" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.13"] + max-parallel: 2 # to avoid "429 Resource has been exhausted" + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v5 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run fmt-check && hatch run test:types + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run test:cov-retry + + - name: Run unit tests with lowest direct dependencies + run: | + hatch run uv pip compile pyproject.toml --resolution lowest-direct --output-file requirements_lowest_direct.txt + hatch run uv pip install -r requirements_lowest_direct.txt + hatch run test:unit + + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch env prune + hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main + hatch run test:unit + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} \ No newline at end of file From c122e2e7f410a44659f887fc5dc1c9c1b4ddfec4 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 15:48:17 +0530 Subject: [PATCH 24/52] feat: add CometAPI integration to the inventory in README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index ef5912673..e4463d729 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [azure-ai-search-haystack](integrations/azure_ai_search/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) | [![Test / azure-ai-search](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/azure_ai_search.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/azure_ai_search.yml) | | [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | | [cohere-haystack](integrations/cohere/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | +| [cometapi-haystack](integrations/cometapi/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/cometapi-haystack.svg)](https://pypi.org/project/cometapi-haystack) | [![Test / cometapi](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cometapi.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cometapi.yml) | | [deepeval-haystack](integrations/deepeval/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/deepeval-haystack.svg)](https://pypi.org/project/deepeval-haystack) | [![Test / deepeval](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | | [fastembed-haystack](integrations/fastembed/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | From e8fea278402bb521c44eebbfd701fa75e8e5bfde Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 16:25:51 +0530 Subject: [PATCH 25/52] feat: refactor CometAPIChatGenerator initialization and update test for async run --- .../cometapi/chat/chat_generator.py | 42 ++++++++++--------- .../tests/test_cometapi_chat_generator.py | 2 +- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index 1cd92dd3e..002fb8340 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -1,8 +1,9 @@ -from typing import Optional, List, Dict, Any, Callable +from typing import Any, Callable, Dict, List, Optional from haystack.components.generators.chat import OpenAIChatGenerator from haystack.utils import Secret + class CometAPIChatGenerator(OpenAIChatGenerator): """ @@ -28,29 +29,30 @@ class CometAPIChatGenerator(OpenAIChatGenerator): """ def __init__( - self, - api_key = Secret.from_env_var("COMET_API_KEY"), - model = "gpt-4o-mini", - streaming_callback: Optional[Callable[[str], None]] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, - max_retries: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, - tools_strict: bool = False, + self, + api_key = Secret.from_env_var("COMET_API_KEY"), + model = "gpt-4o-mini", + streaming_callback: Optional[Callable[[str], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[int] = None, + max_retries: Optional[int] = None, + tools: Optional[List[Dict[str, Any]]] = None, + *, + tools_strict: bool = False, http_client_kwargs: Optional[Dict[str, Any]] = None ): api_base_url = "https://api.cometapi.com/v1" - + super().__init__( - api_key=api_key, - model=model, - api_base_url=api_base_url, - streaming_callback=streaming_callback, - generation_kwargs=generation_kwargs, - timeout=timeout, - max_retries=max_retries, - tools=tools, - tools_strict=tools_strict, + api_key=api_key, + model=model, + api_base_url=api_base_url, + streaming_callback=streaming_callback, + generation_kwargs=generation_kwargs, + timeout=timeout, + max_retries=max_retries, + tools=tools, + tools_strict=tools_strict, http_client_kwargs=http_client_kwargs ) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index cbb19969d..925a001eb 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -94,7 +94,7 @@ def test_init_default_async(self, monkeypatch): assert not component.generation_kwargs @pytest.mark.asyncio - async def test_run_async(self, chat_messages, mock_async_chat_completion, monkeypatch): # noqa: ARG002 + async def test_run_async(self, chat_messages, mock_async_chat_completion, monkeypatch): monkeypatch.setenv("COMET_API_KEY", "fake-api-key") component = CometAPIChatGenerator() response = await component.run_async(chat_messages) From e16da5f6aea14a4494e5e7e0e2fb8f6511938c05 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 16:29:24 +0530 Subject: [PATCH 26/52] feat: add pytz dependency to test environment --- integrations/cometapi/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/cometapi/pyproject.toml b/integrations/cometapi/pyproject.toml index 900975ec3..922732791 100644 --- a/integrations/cometapi/pyproject.toml +++ b/integrations/cometapi/pyproject.toml @@ -55,6 +55,7 @@ fmt-check = "ruff check {args} && ruff format --check {args}" [tool.hatch.envs.test] dependencies = [ + "pytz", "pytest", "pytest-asyncio", "pytest-cov", From 052bac74606fa093299d9b929818bd27e9311ab6 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 16:34:37 +0530 Subject: [PATCH 27/52] feat: clean up CometAPIChatGenerator constructor formatting and add blank line in test file --- .../cometapi/chat/chat_generator.py | 28 +++++++++---------- .../tests/test_cometapi_chat_generator.py | 1 + 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index 002fb8340..a45e220c0 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -5,7 +5,6 @@ class CometAPIChatGenerator(OpenAIChatGenerator): - """ A chat generator that uses the CometAPI for generating chat responses. @@ -29,19 +28,18 @@ class CometAPIChatGenerator(OpenAIChatGenerator): """ def __init__( - self, - api_key = Secret.from_env_var("COMET_API_KEY"), - model = "gpt-4o-mini", - streaming_callback: Optional[Callable[[str], None]] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, - max_retries: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, - *, - tools_strict: bool = False, - http_client_kwargs: Optional[Dict[str, Any]] = None - ): - + self, + api_key=Secret.from_env_var("COMET_API_KEY"), + model="gpt-4o-mini", + streaming_callback: Optional[Callable[[str], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[int] = None, + max_retries: Optional[int] = None, + tools: Optional[List[Dict[str, Any]]] = None, + *, + tools_strict: bool = False, + http_client_kwargs: Optional[Dict[str, Any]] = None, + ): api_base_url = "https://api.cometapi.com/v1" super().__init__( @@ -54,5 +52,5 @@ def __init__( max_retries=max_retries, tools=tools, tools_strict=tools_strict, - http_client_kwargs=http_client_kwargs + http_client_kwargs=http_client_kwargs, ) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 925a001eb..6d0e5fd0c 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -20,6 +20,7 @@ pytestmark = pytest.mark.asyncio + @pytest.fixture def chat_messages(): return [ From 9191c7ee635f35ddbb08abf6e181f8da430c6bcc Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 16:41:24 +0530 Subject: [PATCH 28/52] feat: update type hints for streaming_callback and tools in CometAPIChatGenerator --- .../generators/cometapi/chat/chat_generator.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index a45e220c0..4c4f9ba71 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -1,6 +1,8 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.nodes import Tool, Toolset +from haystack.schema import StreamingChunk from haystack.utils import Secret @@ -31,11 +33,16 @@ def __init__( self, api_key=Secret.from_env_var("COMET_API_KEY"), model="gpt-4o-mini", - streaming_callback: Optional[Callable[[str], None]] = None, + streaming_callback: Optional[ + Union[ + Callable[[StreamingChunk], None], + Callable[[StreamingChunk], Awaitable[None]], + ] + ] = None, generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, max_retries: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, + tools: Optional[Union[List[Tool], Toolset]] = None, *, tools_strict: bool = False, http_client_kwargs: Optional[Dict[str, Any]] = None, From 585afa24e07f9cb9f0de34876cfd28f29948f0f1 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 16:48:09 +0530 Subject: [PATCH 29/52] feat: update streaming_callback type hint to use StreamingCallbackT for CometAPIChatGenerator --- .../generators/cometapi/chat/chat_generator.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index 4c4f9ba71..64645668d 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -1,8 +1,8 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Union from haystack.components.generators.chat import OpenAIChatGenerator -from haystack.nodes import Tool, Toolset -from haystack.schema import StreamingChunk +from haystack.dataclasses import StreamingCallbackT +from haystack.tools import Tool, Toolset from haystack.utils import Secret @@ -33,12 +33,7 @@ def __init__( self, api_key=Secret.from_env_var("COMET_API_KEY"), model="gpt-4o-mini", - streaming_callback: Optional[ - Union[ - Callable[[StreamingChunk], None], - Callable[[StreamingChunk], Awaitable[None]], - ] - ] = None, + streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, max_retries: Optional[int] = None, From a455c8c47cf588e74a2009e172ffa615e61b1d27 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 16:50:30 +0530 Subject: [PATCH 30/52] feat: remove unused Awaitable and Callable imports in chat_generator.py --- .../components/generators/cometapi/chat/chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index 64645668d..0bfa895da 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -1,4 +1,4 @@ -from typing import Any, Awaitable, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import StreamingCallbackT From 63cadfa211a7c0041c82b49ba4bc40391cc2e283 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 17:01:21 +0530 Subject: [PATCH 31/52] feat: add initial Changelog file for CometAPI integration --- integrations/cometapi/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 integrations/cometapi/CHANGELOG.md diff --git a/integrations/cometapi/CHANGELOG.md b/integrations/cometapi/CHANGELOG.md new file mode 100644 index 000000000..825c32f0d --- /dev/null +++ b/integrations/cometapi/CHANGELOG.md @@ -0,0 +1 @@ +# Changelog From e4e50faba22dcbbeed2a236b8327da8a3bf3299b Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 17:48:32 +0530 Subject: [PATCH 32/52] feat: add pytz dependency to test environment and update mypy configuration --- integrations/cometapi/pyproject.toml | 53 ++++--------------- .../components/generators/py.typed | 0 2 files changed, 11 insertions(+), 42 deletions(-) create mode 100644 integrations/cometapi/src/haystack_integrations/components/generators/py.typed diff --git a/integrations/cometapi/pyproject.toml b/integrations/cometapi/pyproject.toml index 922732791..8022b87f8 100644 --- a/integrations/cometapi/pyproject.toml +++ b/integrations/cometapi/pyproject.toml @@ -55,13 +55,13 @@ fmt-check = "ruff check {args} && ruff format --check {args}" [tool.hatch.envs.test] dependencies = [ - "pytz", "pytest", "pytest-asyncio", "pytest-cov", "pytest-rerunfailures", "mypy", - "pip" + "pip", + "pytz" ] [tool.hatch.envs.test.scripts] @@ -70,23 +70,13 @@ integration = 'pytest -m "integration" {args:tests}' all = 'pytest {args:tests}' cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x' -types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +types = "mypy -p haystack_integrations.components.generators.cometapi {args}" -# TODO: remove lint environment once this integration is properly typed -# test environment should be used instead -# https://github.com/deepset-ai/haystack-core-integrations/issues/1771 -[tool.hatch.envs.lint] -installer = "uv" -detached = true -dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "more-itertools"] - -[tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" - -[tool.black] -target-version = ["py38"] -line-length = 120 -skip-string-normalization = true +[tool.mypy] +install_types = true +non_interactive = true +check_untyped_defs = true +disallow_incomplete_defs = true [tool.ruff] target-version = "py38" @@ -102,7 +92,6 @@ select = [ "E", "EM", "F", - "FBT", "I", "ICN", "ISC", @@ -123,8 +112,6 @@ select = [ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", - # Allow boolean positional values in function calls, like `dict.get(... True)` - "FBT003", # Ignore checks for possible passwords "S105", "S106", @@ -135,12 +122,9 @@ ignore = [ "PLR0912", "PLR0913", "PLR0915", - # Ignore unused params - "ARG001", - "ARG002", - "ARG005", - # Allow function call argument defaults e.g. `Secret.from_env_var` + # Misc "B008", + "S101", ] unfixable = [ # Don't touch unused imports @@ -156,35 +140,20 @@ ban-relative-imports = "parents" [tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] -# Examples can use print statements -"examples/**/*" = ["T201"] [tool.coverage.run] source = ["haystack_integrations"] branch = true -parallel = true +parallel = false [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] -[[tool.mypy.overrides]] -module = [ - "haystack.*", - "haystack_integrations.*", - "pytest.*", - "numpy.*", - "jsonref.*", -] -ignore_missing_imports = true - [tool.pytest.ini_options] addopts = "--strict-markers" markers = [ - "unit: unit tests", "integration: integration tests", - "generators: generators tests", ] log_cli = true -asyncio_mode = "auto" diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/py.typed b/integrations/cometapi/src/haystack_integrations/components/generators/py.typed new file mode 100644 index 000000000..e69de29bb From 47c4766b66bdd04c9cec9331850280a8218af97b Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 17:50:52 +0530 Subject: [PATCH 33/52] feat: remove unused mock_async_chat_completion fixture from test_run_async --- integrations/cometapi/tests/test_cometapi_chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 6d0e5fd0c..a1c38a49b 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -95,7 +95,7 @@ def test_init_default_async(self, monkeypatch): assert not component.generation_kwargs @pytest.mark.asyncio - async def test_run_async(self, chat_messages, mock_async_chat_completion, monkeypatch): + async def test_run_async(self, chat_messages, monkeypatch): monkeypatch.setenv("COMET_API_KEY", "fake-api-key") component = CometAPIChatGenerator() response = await component.run_async(chat_messages) From 1cd84d0cfeb57b2c7949f129cb19ecf57792d033 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 18:01:51 +0530 Subject: [PATCH 34/52] feat: update mypy configuration to ignore missing imports for jsonref and add new linting rules --- integrations/cometapi/pyproject.toml | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/integrations/cometapi/pyproject.toml b/integrations/cometapi/pyproject.toml index 8022b87f8..2b31e2288 100644 --- a/integrations/cometapi/pyproject.toml +++ b/integrations/cometapi/pyproject.toml @@ -78,6 +78,12 @@ non_interactive = true check_untyped_defs = true disallow_incomplete_defs = true +[[tool.mypy.overrides]] +module = [ + "jsonref.*", # jsonref does not provide types +] +ignore_missing_imports = true + [tool.ruff] target-version = "py38" line-length = 120 @@ -92,6 +98,7 @@ select = [ "E", "EM", "F", + "FBT", "I", "ICN", "ISC", @@ -112,6 +119,8 @@ select = [ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", # Ignore checks for possible passwords "S105", "S106", @@ -122,9 +131,12 @@ ignore = [ "PLR0912", "PLR0913", "PLR0915", - # Misc + # Ignore unused params + "ARG001", + "ARG002", + "ARG005", + # Allow function call argument defaults e.g. `Secret.from_env_var` "B008", - "S101", ] unfixable = [ # Don't touch unused imports @@ -140,11 +152,13 @@ ban-relative-imports = "parents" [tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] +# Examples can use print statements +"examples/**/*" = ["T201"] [tool.coverage.run] source = ["haystack_integrations"] branch = true -parallel = false +parallel = true [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] @@ -157,3 +171,4 @@ markers = [ "integration: integration tests", ] log_cli = true +asyncio_mode = "auto" From 44fbcc3ca1d95d015dea7c811fa4b637edc2f5fc Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 18:10:29 +0530 Subject: [PATCH 35/52] feat: add type hints for api_key and model parameters in CometAPIChatGenerator --- .../components/generators/cometapi/chat/chat_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index 0bfa895da..e70194fe7 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -31,8 +31,8 @@ class CometAPIChatGenerator(OpenAIChatGenerator): def __init__( self, - api_key=Secret.from_env_var("COMET_API_KEY"), - model="gpt-4o-mini", + api_key: Secret=Secret.from_env_var("COMET_API_KEY"), + model: str="gpt-4o-mini", streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, From 470abc72d21fe1309e5fc423ad8e4c162b6eaacd Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Fri, 10 Oct 2025 18:13:15 +0530 Subject: [PATCH 36/52] style: format parameter definitions in CometAPIChatGenerator constructor --- .../components/generators/cometapi/chat/chat_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index e70194fe7..f8f0cc9c3 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -31,8 +31,8 @@ class CometAPIChatGenerator(OpenAIChatGenerator): def __init__( self, - api_key: Secret=Secret.from_env_var("COMET_API_KEY"), - model: str="gpt-4o-mini", + api_key: Secret = Secret.from_env_var("COMET_API_KEY"), + model: str = "gpt-4o-mini", streaming_callback: Optional[StreamingCallbackT] = None, generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, From 35453f45e3fb188d5173df81b2380debbd9b2843 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Sat, 11 Oct 2025 08:54:32 +0530 Subject: [PATCH 37/52] docs: add CometAPI resources section to README --- integrations/cometapi/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/cometapi/README.md b/integrations/cometapi/README.md index 87edeb7a9..30f33dc19 100644 --- a/integrations/cometapi/README.md +++ b/integrations/cometapi/README.md @@ -5,6 +5,12 @@ ----- +**CometAPI Resources** +- [Website](https://www.cometapi.com/?utm_source=haystack&utm_campaign=integration&utm_medium=integration&utm_content=integration) +- [Documentation](https://api.cometapi.com/doc) +- [Get an API Key](https://api.cometapi.com/console/token) +- [Pricing](https://api.cometapi.com/pricing) + **Table of Contents** - [Installation](#installation) From c0ae33b8ec170feaa238f255f02fe5b048593864 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Sat, 11 Oct 2025 09:17:16 +0530 Subject: [PATCH 38/52] docs: enhance usage description in README for CometAPI integration --- integrations/cometapi/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cometapi/README.md b/integrations/cometapi/README.md index 30f33dc19..58110ec07 100644 --- a/integrations/cometapi/README.md +++ b/integrations/cometapi/README.md @@ -25,7 +25,7 @@ pip install cometapi-haystack ## Usage -This integration provides components to use models via the new Comet APIs. +This integration offers a set of pre-built components that allow developers to interact seamlessly with AI models using the new Comet APIs. ### Chat Generator From 47fdda13d3988efcc3610534959ce1c1494d58c0 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Sat, 11 Oct 2025 10:42:34 +0530 Subject: [PATCH 39/52] Refactor CometAPIChatGenerator tests: consolidate async tests into a single file, update to synchronous methods, and enhance test coverage for initialization, running with parameters, and tool integration. --- .../tests/test_cometapi_chat_generator.py | 838 +++++++++++++++--- .../test_cometapi_chat_generator_async.py | 269 ++++++ 2 files changed, 990 insertions(+), 117 deletions(-) create mode 100644 integrations/cometapi/tests/test_cometapi_chat_generator_async.py diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index a1c38a49b..650ccad95 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -1,24 +1,35 @@ import os from datetime import datetime -from unittest.mock import AsyncMock, patch +from unittest.mock import patch import pytest import pytz -from haystack.dataclasses import ( - ChatMessage, - ChatRole, - StreamingChunk, -) +from haystack import Pipeline +from haystack.components.generators.utils import print_streaming_chunk +from haystack.components.tools import ToolInvoker +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall from haystack.tools import Tool -from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from haystack.utils.auth import Secret +from openai import OpenAIError +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk +from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction +from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails -from haystack_integrations.components.generators.cometapi.chat.chat_generator import ( - CometAPIChatGenerator, -) +from haystack_integrations.components.generators.cometapi.chat.chat_generator import CometAPIChatGenerator -pytestmark = pytest.mark.asyncio + +class CollectorCallback: + """ + Callback to collect streaming chunks for testing purposes. + """ + + def __init__(self): + self.chunks = [] + + def __call__(self, chunk: StreamingChunk) -> None: + self.chunks.append(chunk) @pytest.fixture @@ -36,11 +47,7 @@ def weather(city: str): @pytest.fixture def tools(): - tool_parameters = { - "type": "object", - "properties": {"city": {"type": "string"}}, - "required": ["city"], - } + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} tool = Tool( name="weather", description="useful to determine the weather in a given location", @@ -52,17 +59,14 @@ def tools(): @pytest.fixture -def mock_async_chat_completion(): +def mock_chat_completion(): """ - Mock the Async OpenAI API completion response and reuse it for async tests + Mock the OpenAI API completion response and reuse it for tests """ - with patch( - "openai.resources.chat.completions.AsyncCompletions.create", - new_callable=AsyncMock, - ) as mock_chat_completion_create: + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: completion = ChatCompletion( id="foo", - model="openai/gpt-4o-mini", + model="gpt-4o-mini", object="chat.completion", choices=[ Choice( @@ -73,32 +77,149 @@ def mock_async_chat_completion(): ) ], created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()), - usage={ - "prompt_tokens": 57, - "completion_tokens": 40, - "total_tokens": 97, - }, + usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, ) - # For async mocks, the return value should be awaitable + mock_chat_completion_create.return_value = completion yield mock_chat_completion_create -class TestCometAPIChatGeneratorAsync: - def test_init_default_async(self, monkeypatch): +class TestCometAPIChatGenerator: + def test_init_default(self, monkeypatch): monkeypatch.setenv("COMET_API_KEY", "test-api-key") component = CometAPIChatGenerator() - - assert isinstance(component.async_client, AsyncOpenAI) - assert component.async_client.api_key == "test-api-key" - assert component.async_client.base_url == "https://api.cometapi.com/v1/" + assert component.client.api_key == "test-api-key" + assert component.model == "gpt-4o-mini" + assert component.api_base_url == "https://api.cometapi.com/v1" + assert component.streaming_callback is None assert not component.generation_kwargs - @pytest.mark.asyncio - async def test_run_async(self, chat_messages, monkeypatch): + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("COMET_API_KEY", raising=False) + with pytest.raises(ValueError, match=r"None of the .* environment variables are set"): + CometAPIChatGenerator() + + def test_init_with_parameters(self): + component = CometAPIChatGenerator( + api_key=Secret.from_token("test-api-key"), + model="gpt-4o-mini", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + assert component.client.api_key == "test-api-key" + assert component.model == "gpt-4o-mini" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_to_dict_default(self, monkeypatch): + monkeypatch.setenv("COMET_API_KEY", "test-api-key") + component = CometAPIChatGenerator() + data = component.to_dict() + + assert ( + data["type"] + == "haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator" + ) + + expected_params = { + "api_key": {"env_vars": ["COMET_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gpt-4o-mini", + "streaming_callback": None, + "api_base_url": "https://api.cometapi.com/v1", + "generation_kwargs": {}, + "timeout": None, + "max_retries": None, + "tools": None, + "http_client_kwargs": None, + } + + for key, value in expected_params.items(): + assert data["init_parameters"][key] == value + + def test_to_dict_with_parameters(self, monkeypatch): + monkeypatch.setenv("ENV_VAR", "test-api-key") + component = CometAPIChatGenerator( + api_key=Secret.from_env_var("ENV_VAR"), + model="gpt-4o-mini", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + timeout=10, + max_retries=10, + tools=None, + http_client_kwargs={"proxy": "http://localhost:8080"}, + ) + data = component.to_dict() + + assert ( + data["type"] + == "haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator" + ) + + expected_params = { + "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"}, + "model": "gpt-4o-mini", + "api_base_url": "https://api.cometapi.com/v1", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "timeout": 10, + "max_retries": 10, + "tools": None, + "http_client_kwargs": {"proxy": "http://localhost:8080"}, + } + + for key, value in expected_params.items(): + assert data["init_parameters"][key] == value + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("COMET_API_KEY", "fake-api-key") + data = { + "type": ( + "haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator" + ), + "init_parameters": { + "api_key": {"env_vars": ["COMET_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gpt-4o-mini", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "timeout": 10, + "max_retries": 10, + "tools": None, + "http_client_kwargs": {"proxy": "http://localhost:8080"}, + }, + } + component = CometAPIChatGenerator.from_dict(data) + assert component.model == "gpt-4o-mini" + assert component.streaming_callback is print_streaming_chunk + assert component.api_base_url == "https://api.cometapi.com/v1" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.api_key == Secret.from_env_var("COMET_API_KEY") + assert component.http_client_kwargs == {"proxy": "http://localhost:8080"} + assert component.tools is None + assert component.timeout == 10 + assert component.max_retries == 10 + + def test_from_dict_fail_wo_env_var(self, monkeypatch): + monkeypatch.delenv("COMET_API_KEY", raising=False) + data = { + "type": ( + "haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator" + ), + "init_parameters": { + "api_key": {"env_vars": ["COMET_API_KEY"], "strict": True, "type": "env_var"}, + "model": "gpt-4o-mini", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "timeout": 10, + "max_retries": 10, + }, + } + with pytest.raises(ValueError, match=r"None of the .* environment variables are set"): + CometAPIChatGenerator.from_dict(data) + + def test_run(self, chat_messages, mock_chat_completion, monkeypatch): monkeypatch.setenv("COMET_API_KEY", "fake-api-key") component = CometAPIChatGenerator() - response = await component.run_async(chat_messages) + response = component.run(chat_messages) # check that the component returns the correct ChatMessage response assert isinstance(response, dict) @@ -107,17 +228,16 @@ async def test_run_async(self, chat_messages, monkeypatch): assert len(response["replies"]) == 1 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - @pytest.mark.asyncio - async def test_run_async_with_params(self, chat_messages, mock_async_chat_completion, monkeypatch): + def test_run_with_params(self, chat_messages, mock_chat_completion, monkeypatch): monkeypatch.setenv("COMET_API_KEY", "fake-api-key") component = CometAPIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) - response = await component.run_async(chat_messages) + response = component.run(chat_messages) # check that the component calls the OpenAI API with the correct parameters - _, kwargs = mock_async_chat_completion.call_args + # for cometapi, these are passed in the extra_body parameter + _, kwargs = mock_chat_completion.call_args assert kwargs["max_tokens"] == 10 assert kwargs["temperature"] == 0.5 - # check that the component returns the correct response assert isinstance(response, dict) assert "replies" in response @@ -127,18 +247,17 @@ async def test_run_async_with_params(self, chat_messages, mock_async_chat_comple @pytest.mark.skipif( not os.environ.get("COMET_API_KEY", None), - reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + reason="Export an env var called COMET_API_KEY containing the cometapi API key to run this test.", ) @pytest.mark.integration - @pytest.mark.asyncio - async def test_live_run_async(self): + def test_live_run(self): chat_messages = [ChatMessage.from_user("What's the capital of France")] component = CometAPIChatGenerator() - results = await component.run_async(chat_messages) + results = component.run(chat_messages) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.text - assert "openai/gpt-4o-mini" in message.meta["model"] + assert "gpt-4o-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" @pytest.mark.skipif( @@ -146,121 +265,606 @@ async def test_live_run_async(self): reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", ) @pytest.mark.integration - @pytest.mark.asyncio - async def test_live_run_streaming_async(self): - counter = 0 - responses = "" + def test_live_run_wrong_model(self, chat_messages): + component = CometAPIChatGenerator(model="something-obviously-wrong") + with pytest.raises(OpenAIError): + component.run(chat_messages) + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_streaming(self): + class Callback: + def __init__(self): + self.responses = "" + self.counter = 0 - async def callback(chunk: StreamingChunk): - nonlocal counter - nonlocal responses - counter += 1 - responses += chunk.content if chunk.content else "" + def __call__(self, chunk: StreamingChunk) -> None: + self.counter += 1 + self.responses += chunk.content if chunk.content else "" + callback = Callback() component = CometAPIChatGenerator(streaming_callback=callback) - results = await component.run_async([ChatMessage.from_user("What's the capital of France?")]) + results = component.run([ChatMessage.from_user("What's the capital of France?")]) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.text - assert "openai/gpt-4o-mini" in message.meta["model"] + assert "gpt-4o-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" - assert counter > 1 - assert "Paris" in responses + assert callback.counter > 1 + assert "Paris" in callback.responses + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_tools(self, tools): + chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = CometAPIChatGenerator(tools=tools) + results = component.run(chat_messages) + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert message.text is None + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" @pytest.mark.skipif( not os.environ.get("COMET_API_KEY", None), reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", ) @pytest.mark.integration - @pytest.mark.asyncio - async def test_live_run_with_tools_and_response_async(self, tools): + def test_live_run_with_tools_and_response(self, tools): """ Integration test that the CometAPIChatGenerator component can run with tools and get a response. """ - initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")] component = CometAPIChatGenerator(tools=tools) - results = await component.run_async(messages=initial_messages, generation_kwargs={"tool_choice": "auto"}) + results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "auto"}) - assert len(results["replies"]) > 0, "No replies received" + assert len(results["replies"]) == 1 # Find the message with tool calls - tool_message = None - for message in results["replies"]: - if message.tool_call: - tool_message = message - break - - assert tool_message is not None, "No message with tool call found" - assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" - assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" - - tool_call = tool_message.tool_call - assert tool_call.id, "Tool call does not contain value for 'id' key" - assert tool_call.tool_name == "weather" - assert tool_call.arguments == {"city": "Paris"} + tool_message = results["replies"][0] + + assert isinstance(tool_message, ChatMessage) + tool_calls = tool_message.tool_calls + assert len(tool_calls) == 2 + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT) + + for tool_call in tool_calls: + assert tool_call.id is not None + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + + arguments = [tool_call.arguments for tool_call in tool_calls] + assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}] assert tool_message.meta["finish_reason"] == "tool_calls" new_messages = [ initial_messages[0], tool_message, - ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ChatMessage.from_tool(tool_result="22° C and sunny", origin=tool_calls[0]), + ChatMessage.from_tool(tool_result="16° C and windy", origin=tool_calls[1]), ] # Pass the tool result to the model to get the final response - results = await component.run_async(new_messages) + results = component.run(new_messages) assert len(results["replies"]) == 1 final_message = results["replies"][0] - assert not final_message.tool_call + assert final_message.is_from(ChatRole.ASSISTANT) assert len(final_message.text) > 0 assert "paris" in final_message.text.lower() + assert "berlin" in final_message.text.lower() @pytest.mark.skipif( not os.environ.get("COMET_API_KEY", None), reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", ) @pytest.mark.integration - @pytest.mark.asyncio - async def test_live_run_with_tools_streaming_async(self, tools): + def test_live_run_with_tools_streaming(self, tools): """ Integration test that the CometAPIChatGenerator component can run with tools and streaming. """ - - counter = 0 - tool_calls = [] - - async def callback(chunk: StreamingChunk): - nonlocal counter - nonlocal tool_calls - counter += 1 - if chunk.meta.get("tool_calls"): - tool_calls.extend(chunk.meta["tool_calls"]) - - component = CometAPIChatGenerator(tools=tools, streaming_callback=callback) - results = await component.run_async( - [ChatMessage.from_user("What's the weather like in Paris?")], + component = CometAPIChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) + results = component.run( + [ChatMessage.from_user("What's the weather like in Paris and Berlin?")], generation_kwargs={"tool_choice": "auto"}, ) - assert len(results["replies"]) > 0, "No replies received" - assert counter > 1, "Streaming callback was not called multiple times" - assert tool_calls, "No tool calls received in streaming" + assert len(results["replies"]) == 1 # Find the message with tool calls - tool_message = None - for message in results["replies"]: - if message.tool_call: - tool_message = message - break - - assert tool_message is not None, "No message with tool call found" - assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" - assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" - - tool_call = tool_message.tool_call - assert tool_call.id, "Tool call does not contain value for 'id' key" - assert tool_call.tool_name == "weather" - assert tool_call.arguments == {"city": "Paris"} + tool_message = results["replies"][0] + + assert isinstance(tool_message, ChatMessage) + tool_calls = tool_message.tool_calls + assert len(tool_calls) == 2 + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT) + + for tool_call in tool_calls: + assert tool_call.id is not None + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + + arguments = [tool_call.arguments for tool_call in tool_calls] + assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}] assert tool_message.meta["finish_reason"] == "tool_calls" + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_pipeline_with_cometapi_chat_generator(self, tools): + """ + Test that the CometAPIChatGenerator component can be used in a pipeline + """ + pipeline = Pipeline() + pipeline.add_component("generator", CometAPIChatGenerator(tools=tools)) + pipeline.add_component("tool_invoker", ToolInvoker(tools=tools)) + + pipeline.connect("generator", "tool_invoker") + + results = pipeline.run( + data={ + "generator": { + "messages": [ChatMessage.from_user("What's the weather like in Paris?")], + "generation_kwargs": {"tool_choice": "auto"}, + } + } + ) + + assert ( + "The weather in Paris is sunny and 32°C" + == results["tool_invoker"]["tool_messages"][0].tool_call_result.result + ) + + def test_serde_in_pipeline(self, monkeypatch): + """ + Test serialization/deserialization of CometAPIChatGenerator in a Pipeline, + including YAML conversion and detailed dictionary validation + """ + # Set mock API key + monkeypatch.setenv("COMET_API_KEY", "test-key") + + # Create a test tool + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters={"city": {"type": "string"}}, + function=weather, + ) + + # Create generator with specific configuration + generator = CometAPIChatGenerator( + model="gpt-4o-mini", + generation_kwargs={"temperature": 0.7}, + streaming_callback=print_streaming_chunk, + tools=[tool], + ) + + # Create and configure pipeline + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + # Get pipeline dictionary and verify its structure + pipeline_dict = pipeline.to_dict() + expected_dict = { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": ( + "haystack_integrations.components.generators.cometapi.chat.chat_generator." + "CometAPIChatGenerator" + ), + "init_parameters": { + "model": "gpt-4o-mini", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "api_base_url": "https://api.cometapi.com/v1", + "organization": None, + "generation_kwargs": {"temperature": 0.7}, + "api_key": {"type": "env_var", "env_vars": ["COMET_API_KEY"], "strict": True}, + "timeout": None, + "max_retries": None, + "tools": [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "weather", + "description": "useful to determine the weather in a given location", + "parameters": {"city": {"type": "string"}}, + "function": "test.test_cometapi_chat_generator.weather", + } + } + ], + "tools_strict": False, + "http_client_kwargs": None + } + } + }, + "connections": [], + "connection_type_validation": True + } + + if not hasattr(pipeline, "_connection_type_validation"): + expected_dict.pop("connection_type_validation") + + # add outputs_to_string, inputs_from_state and outputs_to_state tool parameters for compatibility with + # haystack-ai>=2.12.0 + if hasattr(tool, "outputs_to_string"): + expected_dict["components"]["generator"]["init_parameters"]["tools"][0]["data"]["outputs_to_string"] = ( + tool.outputs_to_string + ) + if hasattr(tool, "inputs_from_state"): + expected_dict["components"]["generator"]["init_parameters"]["tools"][0]["data"]["inputs_from_state"] = ( + tool.inputs_from_state + ) + if hasattr(tool, "outputs_to_state"): + expected_dict["components"]["generator"]["init_parameters"]["tools"][0]["data"]["outputs_to_state"] = ( + tool.outputs_to_state + ) + + assert pipeline_dict == expected_dict + + # Test YAML serialization/deserialization + pipeline_yaml = pipeline.dumps() + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + # Verify the loaded pipeline's generator has the same configuration + loaded_generator = new_pipeline.get_component("generator") + assert loaded_generator.model == generator.model + assert loaded_generator.generation_kwargs == generator.generation_kwargs + assert loaded_generator.streaming_callback == generator.streaming_callback + assert len(loaded_generator.tools) == len(generator.tools) + assert loaded_generator.tools[0].name == generator.tools[0].name + assert loaded_generator.tools[0].description == generator.tools[0].description + assert loaded_generator.tools[0].parameters == generator.tools[0].parameters + + +class TestChatCompletionChunkConversion: + def test_handle_stream_response(self): + cometapi_chunks = [ + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk(delta=ChoiceDelta(content="", role="assistant"), index=0, native_finish_reason=None) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="call_zznlVyVfK0GJwY28SShJpDCh", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction(arguments='{"ci'), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction(arguments='ty": '), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction(arguments='"Paris'), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + function=ChoiceDeltaToolCallFunction(arguments='"}'), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id="call_Mh1uOyW3Ys4gwydHjNHILHGX", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + service_tier=None, + system_fingerprint="fp_34a54ae93c", + usage=None, + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='{"ci'), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + function=ChoiceDeltaToolCallFunction(arguments='ty": '), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + function=ChoiceDeltaToolCallFunction(arguments='"Berli'), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + function=ChoiceDeltaToolCallFunction(arguments='n"}'), + type="function", + ) + ], + ), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta(content="", role="assistant"), + finish_reason="tool_calls", + index=0, + native_finish_reason="tool_calls", + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + system_fingerprint="fp_34a54ae93c", + provider="OpenAI", + ), + ChatCompletionChunk( + id="gen-1750162525-tc7ParBHvsqd6rYhCDtK", + choices=[ + ChoiceChunk( + delta=ChoiceDelta(content="", role="assistant"), + index=0, + native_finish_reason=None, + ) + ], + created=1750162525, + model="gpt-4o-mini", + object="chat.completion.chunk", + usage=CompletionUsage( + completion_tokens=42, + prompt_tokens=55, + total_tokens=97, + completion_tokens_details=CompletionTokensDetails(reasoning_tokens=0), + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ), + provider="OpenAI", + ), + ] + + collector_callback = CollectorCallback() + llm = CometAPIChatGenerator(api_key=Secret.from_token("test-api-key")) + result = llm._handle_stream_response(cometapi_chunks, callback=collector_callback)[0] # type: ignore + + # Assert text is empty + assert result.text is None + + # Verify both tool calls were found and processed + assert len(result.tool_calls) == 2 + assert result.tool_calls[0].id == "call_zznlVyVfK0GJwY28SShJpDCh" + assert result.tool_calls[0].tool_name == "weather" + assert result.tool_calls[0].arguments == {"city": "Paris"} + assert result.tool_calls[1].id == "call_Mh1uOyW3Ys4gwydHjNHILHGX" + assert result.tool_calls[1].tool_name == "weather" + assert result.tool_calls[1].arguments == {"city": "Berlin"} + + # Verify meta information + assert result.meta["model"] == "gpt-4o-mini" + assert result.meta["finish_reason"] == "tool_calls" + assert result.meta["index"] == 0 + assert result.meta["completion_start_time"] is not None + assert result.meta["usage"] == { + "completion_tokens": 42, + "prompt_tokens": 55, + "total_tokens": 97, + "completion_tokens_details": { + "accepted_prediction_tokens": None, + "audio_tokens": None, + "reasoning_tokens": 0, + "rejected_prediction_tokens": None, + }, + "prompt_tokens_details": { + "audio_tokens": None, + "cached_tokens": 0, + }, + } diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator_async.py b/integrations/cometapi/tests/test_cometapi_chat_generator_async.py new file mode 100644 index 000000000..7caa40775 --- /dev/null +++ b/integrations/cometapi/tests/test_cometapi_chat_generator_async.py @@ -0,0 +1,269 @@ +import os +from datetime import datetime +from unittest.mock import AsyncMock, patch + +import pytest +import pytz +from haystack.dataclasses import ( + ChatMessage, + ChatRole, + StreamingChunk, +) +from haystack.tools import Tool +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice + +from haystack_integrations.components.generators.cometapi.chat.chat_generator import ( + CometAPIChatGenerator, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France"), + ] + + +def weather(city: str): + """Get weather for a given city.""" + return f"The weather in {city} is sunny and 32°C" + + +@pytest.fixture +def tools(): + tool_parameters = { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=weather, + ) + + return [tool] + + +@pytest.fixture +def mock_async_chat_completion(): + """ + Mock the Async OpenAI API completion response and reuse it for async tests + """ + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + ) as mock_chat_completion_create: + completion = ChatCompletion( + id="foo", + model="openai/gpt-4o-mini", + object="chat.completion", + choices=[ + Choice( + finish_reason="stop", + logprobs=None, + index=0, + message=ChatCompletionMessage(content="Hello world!", role="assistant"), + ) + ], + created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()), + usage={ + "prompt_tokens": 57, + "completion_tokens": 40, + "total_tokens": 97, + }, + ) + # For async mocks, the return value should be awaitable + mock_chat_completion_create.return_value = completion + yield mock_chat_completion_create + + +class TestCometAPIChatGeneratorAsync: + def test_init_default_async(self, monkeypatch): + monkeypatch.setenv("COMET_API_KEY", "test-api-key") + component = CometAPIChatGenerator() + + assert isinstance(component.async_client, AsyncOpenAI) + assert component.async_client.api_key == "test-api-key" + assert component.async_client.base_url == "https://api.cometapi.com/v1/" + assert not component.generation_kwargs + + @pytest.mark.asyncio + async def test_run_async(self, chat_messages, mock_async_chat_completion, monkeypatch): + monkeypatch.setenv("COMET_API_KEY", "fake-api-key") + component = CometAPIChatGenerator() + response = await component.run_async(chat_messages) + + # Verify the mock was called + mock_async_chat_completion.assert_called_once() + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.asyncio + async def test_run_async_with_params(self, chat_messages, mock_async_chat_completion, monkeypatch): + monkeypatch.setenv("COMET_API_KEY", "fake-api-key") + component = CometAPIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) + response = await component.run_async(chat_messages) + + # check that the component calls the OpenAI API with the correct parameters + _, kwargs = mock_async_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_async(self): + chat_messages = [ChatMessage.from_user("What's the capital of France")] + component = CometAPIChatGenerator() + results = await component.run_async(chat_messages) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + assert "gpt-4o-mini" in message.meta["model"] + assert message.meta["finish_reason"] == "stop" + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_streaming_async(self): + counter = 0 + responses = "" + + async def callback(chunk: StreamingChunk): + nonlocal counter + nonlocal responses + counter += 1 + responses += chunk.content if chunk.content else "" + + component = CometAPIChatGenerator(streaming_callback=callback) + results = await component.run_async([ChatMessage.from_user("What's the capital of France?")]) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.text + + assert "gpt-4o-mini" in message.meta["model"] + assert message.meta["finish_reason"] == "stop" + + assert counter > 1 + assert "Paris" in responses + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_with_tools_and_response_async(self, tools): + """ + Integration test that the CometAPIChatGenerator component can run with tools and get a response. + """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = CometAPIChatGenerator(tools=tools) + results = await component.run_async(messages=initial_messages, generation_kwargs={"tool_choice": "auto"}) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_calls" + + new_messages = [ + initial_messages[0], + tool_message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + # Pass the tool result to the model to get the final response + results = await component.run_async(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.skipif( + not os.environ.get("COMET_API_KEY", None), + reason="Export an env var called COMET_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_with_tools_streaming_async(self, tools): + """ + Integration test that the CometAPIChatGenerator component can run with tools and streaming. + """ + + counter = 0 + tool_calls = [] + + async def callback(chunk: StreamingChunk): + nonlocal counter + nonlocal tool_calls + counter += 1 + if chunk.meta.get("tool_calls"): + tool_calls.extend(chunk.meta["tool_calls"]) + + component = CometAPIChatGenerator(tools=tools, streaming_callback=callback) + results = await component.run_async( + [ChatMessage.from_user("What's the weather like in Paris?")], + generation_kwargs={"tool_choice": "auto"}, + ) + + assert len(results["replies"]) > 0, "No replies received" + assert counter > 1, "Streaming callback was not called multiple times" + assert tool_calls, "No tool calls received in streaming" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_calls" From 7fcf5ba7efaf55953d52bd59bbc0b55eaa3310f7 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Sat, 11 Oct 2025 10:44:08 +0530 Subject: [PATCH 40/52] refactor: simplify type declarations in TestCometAPIChatGenerator tests --- .../tests/test_cometapi_chat_generator.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 650ccad95..86ae874ec 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -173,9 +173,7 @@ def test_to_dict_with_parameters(self, monkeypatch): def test_from_dict(self, monkeypatch): monkeypatch.setenv("COMET_API_KEY", "fake-api-key") data = { - "type": ( - "haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator" - ), + "type": ("haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator"), "init_parameters": { "api_key": {"env_vars": ["COMET_API_KEY"], "strict": True, "type": "env_var"}, "model": "gpt-4o-mini", @@ -201,9 +199,7 @@ def test_from_dict(self, monkeypatch): def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("COMET_API_KEY", raising=False) data = { - "type": ( - "haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator" - ), + "type": ("haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator"), "init_parameters": { "api_key": {"env_vars": ["COMET_API_KEY"], "strict": True, "type": "env_var"}, "model": "gpt-4o-mini", @@ -466,8 +462,7 @@ def test_serde_in_pipeline(self, monkeypatch): "components": { "generator": { "type": ( - "haystack_integrations.components.generators.cometapi.chat.chat_generator." - "CometAPIChatGenerator" + "haystack_integrations.components.generators.cometapi.chat.chat_generator.CometAPIChatGenerator" ), "init_parameters": { "model": "gpt-4o-mini", @@ -486,16 +481,16 @@ def test_serde_in_pipeline(self, monkeypatch): "description": "useful to determine the weather in a given location", "parameters": {"city": {"type": "string"}}, "function": "test.test_cometapi_chat_generator.weather", - } + }, } ], "tools_strict": False, - "http_client_kwargs": None - } + "http_client_kwargs": None, + }, } }, "connections": [], - "connection_type_validation": True + "connection_type_validation": True, } if not hasattr(pipeline, "_connection_type_validation"): From d81d5abb07032388b7ff8b2686e16acc559613b3 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Sat, 11 Oct 2025 10:47:32 +0530 Subject: [PATCH 41/52] fix: correct function reference in TestCometAPIChatGenerator --- integrations/cometapi/tests/test_cometapi_chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 86ae874ec..4871dde92 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -480,7 +480,7 @@ def test_serde_in_pipeline(self, monkeypatch): "name": "weather", "description": "useful to determine the weather in a given location", "parameters": {"city": {"type": "string"}}, - "function": "test.test_cometapi_chat_generator.weather", + "function": "test_cometapi_chat_generator.weather", }, } ], From e9fabf54fdfc761b5956c7d6b0af60179118b946 Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Sat, 11 Oct 2025 11:10:48 +0530 Subject: [PATCH 42/52] refactor: remove redundant YAML serialization/deserialization in TestCometAPIChatGenerator --- .../cometapi/tests/test_cometapi_chat_generator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 4871dde92..6fd4153d6 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -513,13 +513,8 @@ def test_serde_in_pipeline(self, monkeypatch): assert pipeline_dict == expected_dict - # Test YAML serialization/deserialization - pipeline_yaml = pipeline.dumps() - new_pipeline = Pipeline.loads(pipeline_yaml) - assert new_pipeline == pipeline - # Verify the loaded pipeline's generator has the same configuration - loaded_generator = new_pipeline.get_component("generator") + loaded_generator = pipeline.get_component("generator") assert loaded_generator.model == generator.model assert loaded_generator.generation_kwargs == generator.generation_kwargs assert loaded_generator.streaming_callback == generator.streaming_callback From a85ca39424169971eafd71c5614a95b4212e4c9c Mon Sep 17 00:00:00 2001 From: Gary Badwal Date: Tue, 21 Oct 2025 16:27:00 +0530 Subject: [PATCH 43/52] Fixed the type issue --- .../components/generators/cometapi/chat/chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index f8f0cc9c3..1cc469c60 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -37,7 +37,7 @@ def __init__( generation_kwargs: Optional[Dict[str, Any]] = None, timeout: Optional[int] = None, max_retries: Optional[int] = None, - tools: Optional[Union[List[Tool], Toolset]] = None, + tools: Optional[Union[List[Union[Tool, Toolset]], Toolset]] = None, *, tools_strict: bool = False, http_client_kwargs: Optional[Dict[str, Any]] = None, From f08628d6e251b6d0a50e7211bfc3f3e2a9d253c0 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Wed, 22 Oct 2025 16:04:39 +0530 Subject: [PATCH 44/52] Fixes for: 1. enforcing keyword arguments for all init params 2. updated workflow --- .github/workflows/cometapi.yml | 2 +- .../components/generators/cometapi/chat/chat_generator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cometapi.yml b/.github/workflows/cometapi.yml index beb9a5eb9..5d08ba9c3 100644 --- a/.github/workflows/cometapi.yml +++ b/.github/workflows/cometapi.yml @@ -64,7 +64,7 @@ jobs: - name: Run unit tests with lowest direct dependencies run: | hatch run uv pip compile pyproject.toml --resolution lowest-direct --output-file requirements_lowest_direct.txt - hatch run uv pip install -r requirements_lowest_direct.txt + hatch -e test env run -- uv pip install -r requirements_lowest_direct.txt hatch run test:unit - name: Nightly - run unit tests with Haystack main branch diff --git a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py index 1cc469c60..cc389302e 100644 --- a/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py +++ b/integrations/cometapi/src/haystack_integrations/components/generators/cometapi/chat/chat_generator.py @@ -31,6 +31,7 @@ class CometAPIChatGenerator(OpenAIChatGenerator): def __init__( self, + *, api_key: Secret = Secret.from_env_var("COMET_API_KEY"), model: str = "gpt-4o-mini", streaming_callback: Optional[StreamingCallbackT] = None, @@ -38,7 +39,6 @@ def __init__( timeout: Optional[int] = None, max_retries: Optional[int] = None, tools: Optional[Union[List[Union[Tool, Toolset]], Toolset]] = None, - *, tools_strict: bool = False, http_client_kwargs: Optional[Dict[str, Any]] = None, ): From 73d6e038efbff994eb4a3a69c0521e99ce358a9f Mon Sep 17 00:00:00 2001 From: garybadwal Date: Wed, 22 Oct 2025 16:23:27 +0530 Subject: [PATCH 45/52] Fixed the test-cases --- integrations/cometapi/tests/test_cometapi_chat_generator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 6fd4153d6..42029c9dc 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -843,6 +843,9 @@ def test_handle_stream_response(self): assert result.meta["finish_reason"] == "tool_calls" assert result.meta["index"] == 0 assert result.meta["completion_start_time"] is not None + + result.meta["usage"]["completion_tokens_details"] = result.meta["usage"]["completion_tokens_details"].model_dump() + result.meta["usage"]["prompt_tokens_details"] = result.meta["usage"]["prompt_tokens_details"].model_dump() assert result.meta["usage"] == { "completion_tokens": 42, "prompt_tokens": 55, From 86196504a0b2353a8d32bff1036cffc49f4e9012 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Wed, 22 Oct 2025 16:26:35 +0530 Subject: [PATCH 46/52] Fixed the linting issue --- integrations/cometapi/tests/test_cometapi_chat_generator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 42029c9dc..87bdab8a4 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -843,8 +843,10 @@ def test_handle_stream_response(self): assert result.meta["finish_reason"] == "tool_calls" assert result.meta["index"] == 0 assert result.meta["completion_start_time"] is not None - - result.meta["usage"]["completion_tokens_details"] = result.meta["usage"]["completion_tokens_details"].model_dump() + + result.meta["usage"]["completion_tokens_details"] = result.meta["usage"][ + "completion_tokens_details" + ].model_dump() result.meta["usage"]["prompt_tokens_details"] = result.meta["usage"]["prompt_tokens_details"].model_dump() assert result.meta["usage"] == { "completion_tokens": 42, From e19897c809d0ea1747d2e9571ad5bf8c4fc36719 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Wed, 22 Oct 2025 16:35:49 +0530 Subject: [PATCH 47/52] Fix test-cases --- integrations/cometapi/tests/test_cometapi_chat_generator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 87bdab8a4..6fd4153d6 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -843,11 +843,6 @@ def test_handle_stream_response(self): assert result.meta["finish_reason"] == "tool_calls" assert result.meta["index"] == 0 assert result.meta["completion_start_time"] is not None - - result.meta["usage"]["completion_tokens_details"] = result.meta["usage"][ - "completion_tokens_details" - ].model_dump() - result.meta["usage"]["prompt_tokens_details"] = result.meta["usage"]["prompt_tokens_details"].model_dump() assert result.meta["usage"] == { "completion_tokens": 42, "prompt_tokens": 55, From 12190223cd91ed580aafb4c5f15b8982f94fb416 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Wed, 22 Oct 2025 16:42:20 +0530 Subject: [PATCH 48/52] Fix test cases --- .../cometapi/tests/test_cometapi_chat_generator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index 6fd4153d6..ffd397362 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -847,14 +847,14 @@ def test_handle_stream_response(self): "completion_tokens": 42, "prompt_tokens": 55, "total_tokens": 97, - "completion_tokens_details": { + "completion_tokens_details": CompletionTokensDetails(**{ "accepted_prediction_tokens": None, "audio_tokens": None, "reasoning_tokens": 0, "rejected_prediction_tokens": None, - }, - "prompt_tokens_details": { + }), + "prompt_tokens_details": PromptTokensDetails(**{ "audio_tokens": None, "cached_tokens": 0, - }, + }), } From 6b093c6843e9707d26301f5d660c039cda280b91 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Wed, 22 Oct 2025 16:43:43 +0530 Subject: [PATCH 49/52] Fixed formating --- .../tests/test_cometapi_chat_generator.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index ffd397362..e989cde7b 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -847,14 +847,18 @@ def test_handle_stream_response(self): "completion_tokens": 42, "prompt_tokens": 55, "total_tokens": 97, - "completion_tokens_details": CompletionTokensDetails(**{ - "accepted_prediction_tokens": None, - "audio_tokens": None, - "reasoning_tokens": 0, - "rejected_prediction_tokens": None, - }), - "prompt_tokens_details": PromptTokensDetails(**{ - "audio_tokens": None, - "cached_tokens": 0, - }), + "completion_tokens_details": CompletionTokensDetails( + **{ + "accepted_prediction_tokens": None, + "audio_tokens": None, + "reasoning_tokens": 0, + "rejected_prediction_tokens": None, + } + ), + "prompt_tokens_details": PromptTokensDetails( + **{ + "audio_tokens": None, + "cached_tokens": 0, + } + ), } From bab67b413cb32289851027410fee296f0ede8685 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Wed, 22 Oct 2025 16:52:55 +0530 Subject: [PATCH 50/52] Fixed test case --- .../tests/test_cometapi_chat_generator.py | 42 ++++++++++++------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/integrations/cometapi/tests/test_cometapi_chat_generator.py b/integrations/cometapi/tests/test_cometapi_chat_generator.py index e989cde7b..6a3b5c17f 100644 --- a/integrations/cometapi/tests/test_cometapi_chat_generator.py +++ b/integrations/cometapi/tests/test_cometapi_chat_generator.py @@ -1,4 +1,5 @@ import os +from dataclasses import asdict from datetime import datetime from unittest.mock import patch @@ -843,22 +844,33 @@ def test_handle_stream_response(self): assert result.meta["finish_reason"] == "tool_calls" assert result.meta["index"] == 0 assert result.meta["completion_start_time"] is not None - assert result.meta["usage"] == { + + # Normalize usage details before asserting + usage = result.meta["usage"] + + if hasattr(usage["completion_tokens_details"], "model_dump"): + usage["completion_tokens_details"] = usage["completion_tokens_details"].model_dump() + if hasattr(usage["prompt_tokens_details"], "model_dump"): + usage["prompt_tokens_details"] = usage["prompt_tokens_details"].model_dump() + + # For dataclass fallback + if not isinstance(usage["completion_tokens_details"], dict): + usage["completion_tokens_details"] = asdict(usage["completion_tokens_details"]) + if not isinstance(usage["prompt_tokens_details"], dict): + usage["prompt_tokens_details"] = asdict(usage["prompt_tokens_details"]) + + assert usage == { "completion_tokens": 42, "prompt_tokens": 55, "total_tokens": 97, - "completion_tokens_details": CompletionTokensDetails( - **{ - "accepted_prediction_tokens": None, - "audio_tokens": None, - "reasoning_tokens": 0, - "rejected_prediction_tokens": None, - } - ), - "prompt_tokens_details": PromptTokensDetails( - **{ - "audio_tokens": None, - "cached_tokens": 0, - } - ), + "completion_tokens_details": { + "accepted_prediction_tokens": None, + "audio_tokens": None, + "reasoning_tokens": 0, + "rejected_prediction_tokens": None, + }, + "prompt_tokens_details": { + "audio_tokens": None, + "cached_tokens": 0, + }, } From d58b90b15620d9b9f4f6e5310b4812306b846b46 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 31 Oct 2025 09:56:40 +0530 Subject: [PATCH 51/52] Updated Readme.md --- integrations/cometapi/README.md | 49 ++++++--------------------------- 1 file changed, 8 insertions(+), 41 deletions(-) diff --git a/integrations/cometapi/README.md b/integrations/cometapi/README.md index 58110ec07..271a68519 100644 --- a/integrations/cometapi/README.md +++ b/integrations/cometapi/README.md @@ -1,48 +1,15 @@ -# Comet API Haystack Integration +# cometapi-haystack [![PyPI - Version](https://img.shields.io/pypi/v/cometapi-haystack.svg)](https://pypi.org/project/cometapi-haystack) -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/cometapi-haystack.svg)](https://pypi.org/project/cometapi-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/cometapi-haystack.svg)](https://pypi.org/project/cometapihaystack) ------ +- [Integration page](https://haystack.deepset.ai/integrations/cometapi) +- [Changelog](https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/cometapi/CHANGELOG.md) -**CometAPI Resources** -- [Website](https://www.cometapi.com/?utm_source=haystack&utm_campaign=integration&utm_medium=integration&utm_content=integration) -- [Documentation](https://api.cometapi.com/doc) -- [Get an API Key](https://api.cometapi.com/console/token) -- [Pricing](https://api.cometapi.com/pricing) +--- -**Table of Contents** +## Contributing -- [Installation](#installation) -- [Usage](#usage) -- [License](#license) +Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md). -## Installation - -```console -pip install cometapi-haystack -``` - -## Usage - -This integration offers a set of pre-built components that allow developers to interact seamlessly with AI models using the new Comet APIs. - -### Chat Generator - -```python -from haystack.dataclasses.chat_message import ChatMessage -from haystack_integrations.components.generators.cometapi import CometAPIChatGenerator - -# Initialize the chat generator -chat_generator = CometAPIChatGenerator(model="grok-3-mini") - -# Generate a response -messages = [ChatMessage.from_user("Tell me about the future of AI")] -response = chat_generator.run(messages=messages) -print(response["replies"][0].text) -``` - - -## License - -`cometapi-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. \ No newline at end of file +To run integration tests locally, you need to export the `COMET_API_KEY` environment variable. From 70c8054a3d4a443ead80baf715c00fd9a1e5b552 Mon Sep 17 00:00:00 2001 From: garybadwal Date: Fri, 31 Oct 2025 10:03:29 +0530 Subject: [PATCH 52/52] Updated README.md --- integrations/cometapi/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cometapi/README.md b/integrations/cometapi/README.md index 271a68519..061529525 100644 --- a/integrations/cometapi/README.md +++ b/integrations/cometapi/README.md @@ -3,7 +3,7 @@ [![PyPI - Version](https://img.shields.io/pypi/v/cometapi-haystack.svg)](https://pypi.org/project/cometapi-haystack) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/cometapi-haystack.svg)](https://pypi.org/project/cometapihaystack) -- [Integration page](https://haystack.deepset.ai/integrations/cometapi) +- [Integration page](https://haystack.deepset.ai/integrations/comet-api) - [Changelog](https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/cometapi/CHANGELOG.md) ---