diff --git a/config/gateway/gateway-plugin/gateway-plugin.yaml b/config/gateway/gateway-plugin/gateway-plugin.yaml index d6bd9e0ad..6a06d0a69 100644 --- a/config/gateway/gateway-plugin/gateway-plugin.yaml +++ b/config/gateway/gateway-plugin/gateway-plugin.yaml @@ -165,6 +165,9 @@ spec: - path: type: PathPrefix value: /v1/completions + - path: + type: PathPrefix + value: /v1/embeddings backendRefs: - name: aibrix-gateway-plugins port: 50052 diff --git a/pkg/controller/modelrouter/modelrouter_controller.go b/pkg/controller/modelrouter/modelrouter_controller.go index 6ead4a856..7855343a4 100644 --- a/pkg/controller/modelrouter/modelrouter_controller.go +++ b/pkg/controller/modelrouter/modelrouter_controller.go @@ -228,6 +228,15 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string modelHeaderMatch, }, }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: ptr.To(gatewayv1.PathMatchPathPrefix), + Value: ptr.To("/v1/embeddings"), + }, + Headers: []gatewayv1.HTTPHeaderMatch{ + modelHeaderMatch, + }, + }, }, BackendRefs: []gatewayv1.HTTPBackendRef{ { diff --git a/pkg/plugins/gateway/util.go b/pkg/plugins/gateway/util.go index 0e57ea540..bc2d0831a 100644 --- a/pkg/plugins/gateway/util.go +++ b/pkg/plugins/gateway/util.go @@ -18,6 +18,7 @@ package gateway import ( "encoding/json" + "fmt" "strings" configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -69,6 +70,40 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user } model = completionObj.Model message = completionObj.Prompt + } else if requestPath == "/v1/embeddings" { + var embeddingReq struct { + Input interface{} `json:"input"` + Model string `json:"model"` + } + if err := json.Unmarshal(requestBody, &embeddingReq); err != nil { + klog.ErrorS(err, "error to unmarshal embeddings object", "requestID", requestID, "requestBody", string(requestBody)) + errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true") + return + } + model = embeddingReq.Model + // Convert input to string for message + switch v := embeddingReq.Input.(type) { + case string: + message = v + case []interface{}: + // Handle array inputs + if len(v) > 0 { + switch elem := v[0].(type) { + case string: + message = elem + case float64: + // Handle token ID (number) + message = fmt.Sprintf("Token array input (first token: %v)", elem) + case []interface{}: + // Handle nested array (number[][]) + if len(elem) > 0 { + if token, ok := elem[0].(float64); ok { + message = fmt.Sprintf("Nested token array input (first token: %v)", token) + } + } + } + } + } } else { errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true") return diff --git a/python/aibrix/aibrix/app.py b/python/aibrix/aibrix/app.py index a16dd1189..529f72e8d 100644 --- a/python/aibrix/aibrix/app.py +++ b/python/aibrix/aibrix/app.py @@ -29,6 +29,7 @@ from aibrix.openapi.model import ModelManager from aibrix.openapi.protocol import ( DownloadModelRequest, + EmbeddingRequest, ErrorResponse, ListModelRequest, LoadLoraAdapterRequest, @@ -188,6 +189,14 @@ async def readiness_check(): return JSONResponse(content={"status": "not ready"}, status_code=503) +@router.post("/v1/embeddings") +async def create_embeddings(request: EmbeddingRequest, raw_request: Request): + response = await inference_engine(raw_request).create_embeddings(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), status_code=response.code) + return JSONResponse(status_code=200, content=response.model_dump()) + + def build_app(args: argparse.Namespace): if args.enable_fastapi_docs: app = FastAPI(debug=False) diff --git a/python/aibrix/aibrix/openapi/engine/base.py b/python/aibrix/aibrix/openapi/engine/base.py index c3604cd89..3c5124908 100644 --- a/python/aibrix/aibrix/openapi/engine/base.py +++ b/python/aibrix/aibrix/openapi/engine/base.py @@ -20,6 +20,8 @@ from packaging.version import Version from aibrix.openapi.protocol import ( + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest, @@ -71,6 +73,16 @@ async def list_models(self) -> Union[ErrorResponse, str]: status_code=HTTPStatus.NOT_IMPLEMENTED, ) + async def create_embeddings( + self, request: EmbeddingRequest + ) -> Union[ErrorResponse, EmbeddingResponse]: + return self._create_error_response( + f"Inference engine {self.name} with version {self.version} " + "not support embeddings", + err_type="NotImplementedError", + status_code=HTTPStatus.NOT_IMPLEMENTED, + ) + def get_inference_engine(engine: str, version: str, endpoint: str) -> InferenceEngine: if engine.lower() == "vllm": diff --git a/python/aibrix/aibrix/openapi/engine/vllm.py b/python/aibrix/aibrix/openapi/engine/vllm.py index 2c9add91a..0eddc2cb6 100644 --- a/python/aibrix/aibrix/openapi/engine/vllm.py +++ b/python/aibrix/aibrix/openapi/engine/vllm.py @@ -21,6 +21,8 @@ from aibrix.logger import init_logger from aibrix.openapi.engine.base import InferenceEngine from aibrix.openapi.protocol import ( + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest, @@ -139,3 +141,37 @@ async def list_models(self) -> Union[ErrorResponse, str]: err_type="ServerError", status_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) + + async def create_embeddings( + self, request: EmbeddingRequest + ) -> Union[ErrorResponse, EmbeddingResponse]: + embeddings_url = urljoin(self.endpoint, "/v1/embeddings") + + try: + response = await self.client.post( + embeddings_url, json=request.model_dump(), headers=self.headers + ) + except httpx.RequestError as e: + logger.error(f"Failed to create embeddings: {e}") + return self._create_error_response( + "Failed to create embeddings", + err_type="ServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + if response.status_code != HTTPStatus.OK: + return self._create_error_response( + f"Failed to create embeddings: {response.text}", + err_type="ServerError", + status_code=HTTPStatus(value=response.status_code), + ) + + try: + return EmbeddingResponse(**response.json()) + except Exception as e: + logger.error(f"Failed to parse embedding response: {e}") + return self._create_error_response( + "Invalid response from inference engine", + err_type="ServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) diff --git a/python/aibrix/aibrix/openapi/protocol.py b/python/aibrix/aibrix/openapi/protocol.py index 47d34d9c4..0951f7f63 100644 --- a/python/aibrix/aibrix/openapi/protocol.py +++ b/python/aibrix/aibrix/openapi/protocol.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field @@ -66,3 +66,29 @@ class ListModelRequest(NoExtraBaseModel): class ListModelResponse(NoExtraBaseModel): object: str = "list" data: List[ModelStatusCard] = Field(default_factory=list) + + +class EmbeddingRequest(NoExtraBaseModel): + input: Union[str, List[str], List[int], List[List[int]]] + model: str + encoding_format: Optional[Literal["float", "base64"]] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + + +class EmbeddingData(NoExtraBaseModel): + object: Literal["embedding"] = "embedding" + embedding: Union[List[float], str] # float array or base64 string + index: int + + +class EmbeddingUsage(NoExtraBaseModel): + prompt_tokens: int + total_tokens: int + + +class EmbeddingResponse(NoExtraBaseModel): + object: Literal["list"] = "list" + data: List[EmbeddingData] + model: str + usage: EmbeddingUsage diff --git a/python/aibrix/tests/test_embedding_integration.py b/python/aibrix/tests/test_embedding_integration.py new file mode 100644 index 000000000..d4dcd23c6 --- /dev/null +++ b/python/aibrix/tests/test_embedding_integration.py @@ -0,0 +1,416 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi.testclient import TestClient + +from aibrix.app import build_app +from aibrix.openapi.engine.vllm import VLLMInferenceEngine +from aibrix.openapi.protocol import ( + EmbeddingData, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingUsage, + ErrorResponse, +) + + +class TestVLLMInferenceEngineEmbeddings: + def setup_method(self): + """Set up test fixtures.""" + self.engine = VLLMInferenceEngine( + name="vllm", + version="0.6.1", + endpoint="http://localhost:8000", + ) + + @pytest.mark.asyncio + async def test_create_embeddings_success(self): + """Test successful embeddings creation.""" + # Mock the VLLM response + mock_response = { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2, 0.3, 0.4], + "index": 0, + } + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 4, "total_tokens": 4}, + } + + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.OK + mock_http_response.json.return_value = mock_response + + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input="Hello world", + model="text-embedding-ada-002", + ) + + result = await self.engine.create_embeddings(request) + + # Verify the result + assert isinstance(result, EmbeddingResponse) + assert result.object == "list" + assert len(result.data) == 1 + assert result.data[0].embedding == [0.1, 0.2, 0.3, 0.4] + assert result.model == "text-embedding-ada-002" + assert result.usage.prompt_tokens == 4 + + # Verify the HTTP call + mock_post.assert_called_once_with( + "http://localhost:8000/v1/embeddings", + json=request.model_dump(), + headers=self.engine.headers, + ) + + @pytest.mark.asyncio + async def test_create_embeddings_batch_input(self): + """Test embeddings creation with batch input.""" + mock_response = { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2, 0.3], + "index": 0, + }, + { + "object": "embedding", + "embedding": [0.4, 0.5, 0.6], + "index": 1, + }, + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 8, "total_tokens": 8}, + } + + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.OK + mock_http_response.json.return_value = mock_response + + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input=["Hello", "World"], + model="text-embedding-ada-002", + ) + + result = await self.engine.create_embeddings(request) + + assert isinstance(result, EmbeddingResponse) + assert len(result.data) == 2 + assert result.data[0].index == 0 + assert result.data[1].index == 1 + + @pytest.mark.asyncio + async def test_create_embeddings_http_error(self): + """Test embeddings creation with HTTP error.""" + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.BAD_REQUEST + mock_http_response.text = "Invalid model" + + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input="Hello world", + model="invalid-model", + ) + + result = await self.engine.create_embeddings(request) + + assert isinstance(result, ErrorResponse) + assert result.type == "ServerError" + assert "Failed to create embeddings: Invalid model" in result.message + assert result.code == HTTPStatus.BAD_REQUEST + + @pytest.mark.asyncio + async def test_create_embeddings_network_error(self): + """Test embeddings creation with network error.""" + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: + mock_post.side_effect = httpx.ConnectError("Connection failed") + + request = EmbeddingRequest( + input="Hello world", + model="text-embedding-ada-002", + ) + + result = await self.engine.create_embeddings(request) + + assert isinstance(result, ErrorResponse) + assert result.type == "ServerError" + assert result.message == "Failed to create embeddings" + assert result.code == HTTPStatus.INTERNAL_SERVER_ERROR + + @pytest.mark.asyncio + async def test_create_embeddings_with_api_key(self): + """Test embeddings creation with API key authentication.""" + engine_with_key = VLLMInferenceEngine( + name="vllm", + version="0.6.1", + endpoint="http://localhost:8000", + api_key="test-api-key", + ) + + mock_response = { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2], + "index": 0, + } + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 2, "total_tokens": 2}, + } + + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.OK + mock_http_response.json.return_value = mock_response + + with patch.object( + engine_with_key.client, "post", new_callable=AsyncMock + ) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + ) + + result = await engine_with_key.create_embeddings(request) + + assert isinstance(result, EmbeddingResponse) + # Verify that the client was created with the Authorization header + assert "Authorization" in engine_with_key.client.headers + assert ( + engine_with_key.client.headers["Authorization"] + == "Bearer test-api-key" + ) + + +class TestEmbeddingsAPIEndpoint: + @pytest.fixture + def app(self): + """Create a test FastAPI app.""" + import argparse + + args = argparse.Namespace(enable_fastapi_docs=True) + return build_app(args) + + @pytest.fixture + def client(self, app): + """Create a test client.""" + return TestClient(app) + + def test_embeddings_endpoint_success(self, client): + """Test the /v1/embeddings endpoint with successful response.""" + # Mock the inference engine + mock_response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2, 0.3], index=0), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=3, total_tokens=3), + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_response) + mock_inference_engine.return_value = mock_engine + + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + "model": "text-embedding-ada-002", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["embedding"] == [0.1, 0.2, 0.3] + assert data["model"] == "text-embedding-ada-002" + + def test_embeddings_endpoint_error(self, client): + """Test the /v1/embeddings endpoint with error response.""" + mock_error = ErrorResponse( + message="Model not found", + type="NotFoundError", + code=404, + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_error) + mock_inference_engine.return_value = mock_engine + + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + "model": "non-existent-model", + }, + ) + + assert response.status_code == 404 + data = response.json() + assert data["message"] == "Model not found" + assert data["type"] == "NotFoundError" + + def test_embeddings_endpoint_validation_error(self, client): + """Test the /v1/embeddings endpoint with validation error.""" + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + # Missing required 'model' field + }, + ) + + assert response.status_code == 422 # Validation error + data = response.json() + assert "detail" in data + + def test_embeddings_endpoint_different_input_types(self, client): + """Test the endpoint with different input types.""" + mock_response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2], index=0), + EmbeddingData(embedding=[0.3, 0.4], index=1), + ], + model="test-model", + usage=EmbeddingUsage(prompt_tokens=4, total_tokens=4), + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_response) + mock_inference_engine.return_value = mock_engine + + # Test with string array + response = client.post( + "/v1/embeddings", + json={ + "input": ["Hello", "World"], + "model": "test-model", + }, + ) + assert response.status_code == 200 + + # Test with token array + response = client.post( + "/v1/embeddings", + json={ + "input": [101, 102, 103], + "model": "test-model", + }, + ) + assert response.status_code == 200 + + # Test with nested token array + response = client.post( + "/v1/embeddings", + json={ + "input": [[101, 102], [103, 104]], + "model": "test-model", + }, + ) + assert response.status_code == 200 + + def test_embeddings_endpoint_optional_parameters(self, client): + """Test the endpoint with optional parameters.""" + mock_response = EmbeddingResponse( + data=[ + EmbeddingData(embedding="base64encodedstring", index=0), + ], + model="test-model", + usage=EmbeddingUsage(prompt_tokens=3, total_tokens=3), + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_response) + mock_inference_engine.return_value = mock_engine + + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + "model": "test-model", + "encoding_format": "base64", + "dimensions": 256, + "user": "test-user", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["data"][0]["embedding"] == "base64encodedstring" + + # Verify the request was passed correctly + call_args = mock_engine.create_embeddings.call_args[0][0] + assert call_args.encoding_format == "base64" + assert call_args.dimensions == 256 + assert call_args.user == "test-user" + + +class TestBaseInferenceEngineEmbeddings: + def test_base_engine_not_implemented(self): + """Test that base inference engine returns NotImplementedError.""" + from aibrix.openapi.engine.base import InferenceEngine + + engine = InferenceEngine( + name="test", + version="1.0", + endpoint="http://localhost:8000", + ) + + request = EmbeddingRequest( + input="test", + model="test-model", + ) + + import asyncio + + result = asyncio.run(engine.create_embeddings(request)) + + assert isinstance(result, ErrorResponse) + assert result.type == "NotImplementedError" + assert result.code == HTTPStatus.NOT_IMPLEMENTED + assert "not support embeddings" in result.message + diff --git a/python/aibrix/tests/test_embedding_protocol.py b/python/aibrix/tests/test_embedding_protocol.py new file mode 100644 index 000000000..e1ee47b00 --- /dev/null +++ b/python/aibrix/tests/test_embedding_protocol.py @@ -0,0 +1,294 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pydantic import ValidationError + +from aibrix.openapi.protocol import ( + EmbeddingData, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingUsage, +) + + +class TestEmbeddingRequest: + def test_valid_string_input(self): + """Test EmbeddingRequest with valid string input.""" + request = EmbeddingRequest( + input="Hello, world!", + model="text-embedding-ada-002", + ) + assert request.input == "Hello, world!" + assert request.model == "text-embedding-ada-002" + assert request.encoding_format == "float" # default value + assert request.dimensions is None + assert request.user is None + + def test_valid_string_array_input(self): + """Test EmbeddingRequest with string array input.""" + request = EmbeddingRequest( + input=["Hello", "World"], + model="text-embedding-ada-002", + ) + assert request.input == ["Hello", "World"] + assert request.model == "text-embedding-ada-002" + + def test_valid_token_array_input(self): + """Test EmbeddingRequest with token array input.""" + request = EmbeddingRequest( + input=[101, 102, 103], + model="text-embedding-ada-002", + ) + assert request.input == [101, 102, 103] + + def test_valid_nested_token_array_input(self): + """Test EmbeddingRequest with nested token array input.""" + request = EmbeddingRequest( + input=[[101, 102], [103, 104]], + model="text-embedding-ada-002", + ) + assert request.input == [[101, 102], [103, 104]] + + def test_optional_parameters(self): + """Test EmbeddingRequest with optional parameters.""" + request = EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + encoding_format="base64", + dimensions=512, + user="user123", + ) + assert request.encoding_format == "base64" + assert request.dimensions == 512 + assert request.user == "user123" + + def test_invalid_encoding_format(self): + """Test EmbeddingRequest with invalid encoding format.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + encoding_format="invalid", + ) + assert "Input should be 'float' or 'base64'" in str(exc_info.value) + + def test_missing_required_fields(self): + """Test EmbeddingRequest with missing required fields.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingRequest(input="test") + assert "Field required" in str(exc_info.value) + + def test_extra_fields_not_allowed(self): + """Test that extra fields are not allowed.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + extra_field="not_allowed", + ) + assert "Extra inputs are not permitted" in str(exc_info.value) + + def test_model_dump(self): + """Test serialization of EmbeddingRequest.""" + request = EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + dimensions=256, + ) + data = request.model_dump() + assert data["input"] == "test" + assert data["model"] == "text-embedding-ada-002" + assert data["encoding_format"] == "float" + assert data["dimensions"] == 256 + assert data["user"] is None + + +class TestEmbeddingData: + def test_valid_float_embedding(self): + """Test EmbeddingData with float array embedding.""" + data = EmbeddingData( + embedding=[0.1, 0.2, 0.3, 0.4], + index=0, + ) + assert data.object == "embedding" + assert data.embedding == [0.1, 0.2, 0.3, 0.4] + assert data.index == 0 + + def test_valid_base64_embedding(self): + """Test EmbeddingData with base64 string embedding.""" + data = EmbeddingData( + embedding="base64encodedstring", + index=1, + ) + assert data.object == "embedding" + assert data.embedding == "base64encodedstring" + assert data.index == 1 + + def test_object_literal_fixed(self): + """Test that object field is always 'embedding'.""" + data = EmbeddingData( + embedding=[0.1, 0.2], + index=0, + ) + assert data.object == "embedding" + # Cannot override the literal value + + def test_missing_required_fields(self): + """Test EmbeddingData with missing required fields.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingData(embedding=[0.1, 0.2]) + assert "Field required" in str(exc_info.value) + + +class TestEmbeddingUsage: + def test_valid_usage(self): + """Test EmbeddingUsage with valid token counts.""" + usage = EmbeddingUsage( + prompt_tokens=10, + total_tokens=10, + ) + assert usage.prompt_tokens == 10 + assert usage.total_tokens == 10 + + def test_zero_tokens(self): + """Test EmbeddingUsage with zero tokens.""" + usage = EmbeddingUsage( + prompt_tokens=0, + total_tokens=0, + ) + assert usage.prompt_tokens == 0 + assert usage.total_tokens == 0 + + def test_missing_fields(self): + """Test EmbeddingUsage with missing fields.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingUsage(prompt_tokens=10) + assert "Field required" in str(exc_info.value) + + +class TestEmbeddingResponse: + def test_valid_response(self): + """Test EmbeddingResponse with valid data.""" + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2, 0.3], index=0), + EmbeddingData(embedding=[0.4, 0.5, 0.6], index=1), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=6, total_tokens=6), + ) + assert response.object == "list" + assert len(response.data) == 2 + assert response.model == "text-embedding-ada-002" + assert response.usage.prompt_tokens == 6 + assert response.usage.total_tokens == 6 + + def test_empty_data_list(self): + """Test EmbeddingResponse with empty data list.""" + response = EmbeddingResponse( + data=[], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=0, total_tokens=0), + ) + assert len(response.data) == 0 + + def test_model_dump_json(self): + """Test JSON serialization of EmbeddingResponse.""" + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2], index=0), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=3, total_tokens=3), + ) + data = response.model_dump() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["object"] == "embedding" + assert data["data"][0]["embedding"] == [0.1, 0.2] + assert data["data"][0]["index"] == 0 + assert data["model"] == "text-embedding-ada-002" + assert data["usage"]["prompt_tokens"] == 3 + assert data["usage"]["total_tokens"] == 3 + + def test_base64_response(self): + """Test EmbeddingResponse with base64 encoded embeddings.""" + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding="base64string1", index=0), + EmbeddingData(embedding="base64string2", index=1), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=10, total_tokens=10), + ) + assert response.data[0].embedding == "base64string1" + assert response.data[1].embedding == "base64string2" + + +class TestEmbeddingProtocolIntegration: + def test_request_response_cycle(self): + """Test a complete request-response cycle.""" + # Create a request + request = EmbeddingRequest( + input=["Hello", "World"], + model="text-embedding-ada-002", + encoding_format="float", + ) + + # Simulate processing and create response + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2, 0.3], index=0), + EmbeddingData(embedding=[0.4, 0.5, 0.6], index=1), + ], + model=request.model, + usage=EmbeddingUsage(prompt_tokens=4, total_tokens=4), + ) + + # Verify the response matches the request + assert response.model == request.model + assert len(response.data) == len(request.input) + + def test_mixed_input_types(self): + """Test that different input types are properly validated.""" + # Valid cases + valid_inputs = [ + "single string", + ["multiple", "strings"], + [1, 2, 3, 4], + [[1, 2], [3, 4]], + ] + + for input_val in valid_inputs: + request = EmbeddingRequest( + input=input_val, + model="test-model", + ) + assert request.input == input_val + + # Invalid cases + invalid_inputs = [ + {"dict": "not allowed"}, + 12.34, # float not allowed + True, # boolean not allowed + ] + + for input_val in invalid_inputs: + with pytest.raises(ValidationError): + EmbeddingRequest( + input=input_val, + model="test-model", + ) +