diff --git a/meilisearch/index.py b/meilisearch/index.py
index 812c3a57..a3a79f6b 100644
--- a/meilisearch/index.py
+++ b/meilisearch/index.py
@@ -25,6 +25,7 @@ from meilisearch.errors import version_error_hint_message
 from meilisearch.models.document import Document, DocumentsResults
 from meilisearch.models.embedders import (
+    CompositeEmbedder,
     Embedders,
     EmbedderType,
     HuggingFaceEmbedder,
@@ -977,6 +978,8 @@ def get_settings(self) -> Dict[str, Any]:
                 embedders[k] = HuggingFaceEmbedder(**v)
             elif v.get("source") == "rest":
                 embedders[k] = RestEmbedder(**v)
+            elif v.get("source") == "composite":
+                embedders[k] = CompositeEmbedder(**v)
             else:
                 embedders[k] = UserProvidedEmbedder(**v)
@@ -1934,6 +1937,8 @@ def get_embedders(self) -> Embedders | None:
                 embedders[k] = OllamaEmbedder(**v)
             elif source == "rest":
                 embedders[k] = RestEmbedder(**v)
+            elif source == "composite":
+                embedders[k] = CompositeEmbedder(**v)
             elif source == "userProvided":
                 embedders[k] = UserProvidedEmbedder(**v)
             else:
@@ -1977,6 +1982,8 @@ def update_embedders(self, body: Union[MutableMapping[str, Any], None]) -> TaskI
                 embedders[k] = OllamaEmbedder(**v)
             elif source == "rest":
                 embedders[k] = RestEmbedder(**v)
+            elif source == "composite":
+                embedders[k] = CompositeEmbedder(**v)
             elif source == "userProvided":
                 embedders[k] = UserProvidedEmbedder(**v)
             else:
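For reference, a minimal sketch (not part of the patch) of the payload shape the new `composite` branches handle: `v` below mirrors what the settings API returns for a composite embedder, and the keyword expansion is the same `CompositeEmbedder(**v)` construction added above. The camelCase keys map onto the model's snake_case fields through the `CamelBase` aliases; `CompositeEmbedder` itself is defined in the model changes that follow.

```python
from meilisearch.models.embedders import CompositeEmbedder

# Hypothetical settings fragment; field names follow the Meilisearch embedder API.
v = {
    "source": "composite",
    "searchEmbedder": {"source": "huggingFace"},
    "indexingEmbedder": {"source": "huggingFace"},
}
embedder = CompositeEmbedder(**v)  # same construction the new elif branches perform
```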
diff --git a/meilisearch/models/embedders.py b/meilisearch/models/embedders.py
index 01ba7b3c..9dcd5d00 100644
--- a/meilisearch/models/embedders.py
+++ b/meilisearch/models/embedders.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from enum import Enum
 from typing import Any, Dict, Optional, Union
 
 from camel_converter.pydantic_base import CamelBase
@@ -20,6 +21,24 @@ class Distribution(CamelBase):
     sigma: float
 
 
+class PoolingType(str, Enum):
+    """Pooling strategies for HuggingFaceEmbedder.
+
+    Attributes
+    ----------
+    USE_MODEL : str
+        Use the model's default pooling strategy.
+    FORCE_MEAN : str
+        Force mean pooling over the token embeddings.
+    FORCE_CLS : str
+        Use the [CLS] token embedding as the sentence representation.
+    """
+
+    USE_MODEL = "useModel"
+    FORCE_MEAN = "forceMean"
+    FORCE_CLS = "forceCls"
+
+
 class OpenAiEmbedder(CamelBase):
     """OpenAI embedder configuration.
 
@@ -79,6 +98,8 @@ class HuggingFaceEmbedder(CamelBase):
        Describes the natural distribution of search results
     binary_quantized: Optional[bool]
        Once set to true, irreversibly converts all vector dimensions to 1-bit values
+    pooling: Optional[PoolingType]
+       Configures how individual tokens are merged into a single embedding
     """
 
     source: str = "huggingFace"
@@ -90,6 +111,7 @@ class HuggingFaceEmbedder(CamelBase):
     document_template_max_bytes: Optional[int] = None  # Default to 400
     distribution: Optional[Distribution] = None
     binary_quantized: Optional[bool] = None
+    pooling: Optional[PoolingType] = PoolingType.USE_MODEL
 
 
 class OllamaEmbedder(CamelBase):
@@ -191,6 +213,45 @@ class UserProvidedEmbedder(CamelBase):
     binary_quantized: Optional[bool] = None
 
 
+class CompositeEmbedder(CamelBase):
+    """Composite embedder configuration.
+
+    Parameters
+    ----------
+    source: str
+        The embedder source, must be "composite"
+    indexing_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]
+    search_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]"""
+
+    source: str = "composite"
+    search_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]
+    indexing_embedder: Union[
+        OpenAiEmbedder,
+        HuggingFaceEmbedder,
+        OllamaEmbedder,
+        RestEmbedder,
+        UserProvidedEmbedder,
+    ]
+
+
 # Type alias for the embedder union
 EmbedderType = Union[
     OpenAiEmbedder,
@@ -198,6 +259,7 @@ class UserProvidedEmbedder(CamelBase):
     OllamaEmbedder,
     RestEmbedder,
     UserProvidedEmbedder,
+    CompositeEmbedder,
 ]
 
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 54e53e4e..207d086c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -274,3 +274,20 @@ def new_embedders():
         "default": UserProvidedEmbedder(dimensions=1).model_dump(by_alias=True),
         "open_ai": OpenAiEmbedder().model_dump(by_alias=True),
     }
+
+
+@fixture
+def enable_composite_embedders():
+    requests.patch(
+        f"{common.BASE_URL}/experimental-features",
+        headers={"Authorization": f"Bearer {common.MASTER_KEY}"},
+        json={"compositeEmbedders": True},
+        timeout=10,
+    )
+    yield
+    requests.patch(
+        f"{common.BASE_URL}/experimental-features",
+        headers={"Authorization": f"Bearer {common.MASTER_KEY}"},
+        json={"compositeEmbedders": False},
+        timeout=10,
+    )
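As a quick illustration (again, not part of the patch) of the new `pooling` field: the `model_dump(by_alias=True, exclude_none=True)` call below is the same one the test file uses to build its payload, and the camelCase keys come from `CamelBase`. The printed output is approximate.

```python
from meilisearch.models.embedders import HuggingFaceEmbedder, PoolingType

hf = HuggingFaceEmbedder(pooling=PoolingType.FORCE_MEAN)

# With exclude_none=True only populated fields survive; in pydantic's default
# "python" mode the enum member itself is kept in the dump.
print(hf.model_dump(by_alias=True, exclude_none=True))
# {'source': 'huggingFace', 'pooling': <PoolingType.FORCE_MEAN: 'forceMean'>}

# mode="json" yields the plain string value expected by the HTTP API.
print(hf.model_dump(mode="json", by_alias=True, exclude_none=True))
# {'source': 'huggingFace', 'pooling': 'forceMean'}
```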
diff --git a/tests/settings/test_settings_embedders.py b/tests/settings/test_settings_embedders.py
index 5baf2e09..333678bc 100644
--- a/tests/settings/test_settings_embedders.py
+++ b/tests/settings/test_settings_embedders.py
@@ -1,6 +1,14 @@
 # pylint: disable=redefined-outer-name
 
-from meilisearch.models.embedders import OpenAiEmbedder, UserProvidedEmbedder
+import pytest
+
+from meilisearch.models.embedders import (
+    CompositeEmbedder,
+    HuggingFaceEmbedder,
+    OpenAiEmbedder,
+    PoolingType,
+    UserProvidedEmbedder,
+)
 
 
 def test_get_default_embedders(empty_index):
@@ -97,6 +105,7 @@ def test_huggingface_embedder_format(empty_index):
     assert embedders.embedders["huggingface"].distribution.mean == 0.5
     assert embedders.embedders["huggingface"].distribution.sigma == 0.1
     assert embedders.embedders["huggingface"].binary_quantized is False
+    assert embedders.embedders["huggingface"].pooling is PoolingType.USE_MODEL
 
 
 def test_ollama_embedder_format(empty_index):
@@ -183,3 +192,43 @@ def test_user_provided_embedder_format(empty_index):
     assert embedders.embedders["user_provided"].distribution.mean == 0.5
     assert embedders.embedders["user_provided"].distribution.sigma == 0.1
     assert embedders.embedders["user_provided"].binary_quantized is False
+
+
+@pytest.mark.usefixtures("enable_composite_embedders")
+def test_composite_embedder_format(empty_index):
+    """Tests that the composite embedder has the required fields and proper format."""
+    index = empty_index()
+
+    embedder = HuggingFaceEmbedder().model_dump(by_alias=True, exclude_none=True)
+
+    # create composite embedder
+    composite_embedder = {
+        "composite": {
+            "source": "composite",
+            "searchEmbedder": embedder,
+            "indexingEmbedder": embedder,
+        }
+    }
+
+    response = index.update_embedders(composite_embedder)
+    update = index.wait_for_task(response.task_uid)
+    embedders = index.get_embedders()
+    assert update.status == "succeeded"
+
+    assert embedders.embedders["composite"].source == "composite"
+
+    # ensure serialization roundtrips nicely
+    assert isinstance(embedders.embedders["composite"], CompositeEmbedder)
+    assert isinstance(embedders.embedders["composite"].search_embedder, HuggingFaceEmbedder)
+    assert isinstance(embedders.embedders["composite"].indexing_embedder, HuggingFaceEmbedder)
+
+    # ensure search_embedder has no document_template
+    assert getattr(embedders.embedders["composite"].search_embedder, "document_template") is None
+    assert (
+        getattr(
+            embedders.embedders["composite"].search_embedder,
+            "document_template_max_bytes",
+        )
+        is None
+    )
+    assert getattr(embedders.embedders["composite"].indexing_embedder, "document_template")
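Taken together, the changes support a workflow like the following sketch (not part of the patch): the server URL, key, and index name are placeholders, and the experimental-feature toggle mirrors the `enable_composite_embedders` fixture above.

```python
import requests

import meilisearch
from meilisearch.models.embedders import HuggingFaceEmbedder

BASE_URL = "http://127.0.0.1:7700"  # placeholder
MASTER_KEY = "masterKey"  # placeholder

# Composite embedders are gated behind an experimental feature flag.
requests.patch(
    f"{BASE_URL}/experimental-features",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={"compositeEmbedders": True},
    timeout=10,
)

client = meilisearch.Client(BASE_URL, MASTER_KEY)
index = client.index("movies")  # assumes this index already exists

hf = HuggingFaceEmbedder().model_dump(by_alias=True, exclude_none=True)
task = index.update_embedders(
    {"default": {"source": "composite", "searchEmbedder": hf, "indexingEmbedder": hf}}
)
index.wait_for_task(task.task_uid)

embedders = index.get_embedders()
print(type(embedders.embedders["default"]))  # expected: CompositeEmbedder
```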