From 6a7978e08496b615317881e1ebbd7be8f10d4ea0 Mon Sep 17 00:00:00 2001 From: dittops Date: Wed, 18 Jun 2025 13:52:56 +0000 Subject: [PATCH 1/6] feat: implement /v1/embeddings endpoint for OpenAI-compatible embeddings support Signed-off-by: dittops --- EMBEDDINGS_GUIDE.md | 476 ++++++++++++++++++ pkg/plugins/gateway/util.go | 23 + python/aibrix/aibrix/app.py | 9 + python/aibrix/aibrix/openapi/engine/base.py | 12 + python/aibrix/aibrix/openapi/engine/vllm.py | 28 ++ python/aibrix/aibrix/openapi/protocol.py | 28 +- .../tests/test_embedding_integration.py | 406 +++++++++++++++ .../aibrix/tests/test_embedding_protocol.py | 293 +++++++++++ 8 files changed, 1274 insertions(+), 1 deletion(-) create mode 100644 EMBEDDINGS_GUIDE.md create mode 100644 python/aibrix/tests/test_embedding_integration.py create mode 100644 python/aibrix/tests/test_embedding_protocol.py diff --git a/EMBEDDINGS_GUIDE.md b/EMBEDDINGS_GUIDE.md new file mode 100644 index 000000000..155863e37 --- /dev/null +++ b/EMBEDDINGS_GUIDE.md @@ -0,0 +1,476 @@ +# AIBrix Embeddings API Guide + +This guide provides comprehensive documentation for using the `/v1/embeddings` endpoint in AIBrix. + +## Overview + +The `/v1/embeddings` endpoint enables you to generate vector embeddings from text inputs using various embedding models. This endpoint follows the OpenAI embeddings API specification, ensuring compatibility with existing tools and libraries. + +## Table of Contents + +- [Prerequisites](#prerequisites) +- [API Reference](#api-reference) +- [Usage Examples](#usage-examples) +- [Configuration](#configuration) +- [Error Handling](#error-handling) +- [Performance Considerations](#performance-considerations) + +## Prerequisites + +### vLLM Configuration + +To use embeddings with AIBrix, you need: + +1. **vLLM version 0.4.3 or higher** +2. 
**Models loaded with embedding task support** + +```bash +# Start vLLM with embedding support +python -m vllm.entrypoints.openai.api_server \ + --model sentence-transformers/all-MiniLM-L6-v2 \ + --task embed \ + --port 8000 +``` + +### Supported Models + +Common embedding models that work with vLLM: +- `sentence-transformers/all-MiniLM-L6-v2` +- `sentence-transformers/all-mpnet-base-v2` +- `intfloat/e5-large-v2` +- `BAAI/bge-large-en-v1.5` + +## API Reference + +### Request Format + +```http +POST /v1/embeddings +Content-Type: application/json + +{ + "input": "string | string[] | number[] | number[][]", + "model": "string", + "encoding_format": "float | base64", // optional, default: "float" + "dimensions": "number", // optional + "user": "string" // optional +} +``` + +### Response Format + +```json +{ + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2, 0.3, ...], + "index": 0 + } + ], + "model": "sentence-transformers/all-MiniLM-L6-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } +} +``` + +### Parameters + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `input` | `string \| string[] \| number[] \| number[][]` | Yes | Text input(s) to generate embeddings for | +| `model` | `string` | Yes | ID of the model to use | +| `encoding_format` | `string` | No | Format to return embeddings in: `"float"` or `"base64"` | +| `dimensions` | `integer` | No | Number of dimensions for the embedding (model-dependent) | +| `user` | `string` | No | Unique identifier for the user | + +## Usage Examples + +### Single Text Input + +```python +import httpx + +async def get_embedding(): + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8080/v1/embeddings", + json={ + "input": "The quick brown fox jumps over the lazy dog", + "model": "sentence-transformers/all-MiniLM-L6-v2" + } + ) + return response.json() + +# Example response: +# { +# "object": "list", +# "data": [ +# { +# "object": "embedding", +# "embedding": [0.012, -0.045, 0.123, ...], +# "index": 0 +# } +# ], +# "model": "sentence-transformers/all-MiniLM-L6-v2", +# "usage": { +# "prompt_tokens": 9, +# "total_tokens": 9 +# } +# } +``` + +### Batch Text Inputs + +```python +async def get_batch_embeddings(): + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8080/v1/embeddings", + json={ + "input": [ + "Hello world", + "How are you?", + "Goodbye!" 
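+                    # any number of texts can be batched in one request;
+                    # see "Optimal Batch Sizes" below for sizing guidance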
+ ], + "model": "sentence-transformers/all-MiniLM-L6-v2" + } + ) + return response.json() + +# Example response: +# { +# "object": "list", +# "data": [ +# { +# "object": "embedding", +# "embedding": [0.012, -0.045, ...], +# "index": 0 +# }, +# { +# "object": "embedding", +# "embedding": [0.034, -0.067, ...], +# "index": 1 +# }, +# { +# "object": "embedding", +# "embedding": [0.056, -0.089, ...], +# "index": 2 +# } +# ], +# "model": "sentence-transformers/all-MiniLM-L6-v2", +# "usage": { +# "prompt_tokens": 6, +# "total_tokens": 6 +# } +# } +``` + +### Base64 Encoding + +```python +async def get_base64_embeddings(): + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8080/v1/embeddings", + json={ + "input": "Convert this to base64", + "model": "sentence-transformers/all-MiniLM-L6-v2", + "encoding_format": "base64" + } + ) + return response.json() + +# The embedding will be returned as a base64-encoded string +``` + +### Token Array Input + +```python +async def get_token_embeddings(): + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8080/v1/embeddings", + json={ + "input": [101, 7592, 2088, 102], # Tokenized input + "model": "sentence-transformers/all-MiniLM-L6-v2" + } + ) + return response.json() +``` + +### Using with OpenAI Client + +```python +from openai import OpenAI + +# Configure client to use AIBrix +client = OpenAI( + base_url="http://localhost:8080/v1", + api_key="not-needed" # AIBrix doesn't require API key by default +) + +# Generate embeddings +response = client.embeddings.create( + input="Your text here", + model="sentence-transformers/all-MiniLM-L6-v2" +) + +embedding = response.data[0].embedding +print(f"Embedding dimension: {len(embedding)}") +``` + +## Configuration + +### Environment Variables + +Configure AIBrix for embeddings support: + +```bash +# Set the inference engine +export INFERENCE_ENGINE=vllm +export INFERENCE_ENGINE_VERSION=0.6.1 +export INFERENCE_ENGINE_ENDPOINT=http://localhost:8000 + +# Optional: Set routing strategy +export ROUTING_ALGORITHM=random +``` + +### Model Configuration + +Ensure your embedding model is properly configured in vLLM: + +```bash +# Example vLLM startup with specific parameters +python -m vllm.entrypoints.openai.api_server \ + --model sentence-transformers/all-MiniLM-L6-v2 \ + --task embed \ + --port 8000 \ + --max-model-len 512 \ + --trust-remote-code +``` + +## Error Handling + +### Common Error Responses + +#### Model Not Found (404) +```json +{ + "object": "error", + "message": "Model 'non-existent-model' not found", + "type": "NotFoundError", + "code": 404 +} +``` + +#### Invalid Input Format (400) +```json +{ + "object": "error", + "message": "Invalid input format", + "type": "BadRequestError", + "code": 400 +} +``` + +#### Model Doesn't Support Embeddings (501) +```json +{ + "object": "error", + "message": "Inference engine vllm with version 0.6.1 not support embeddings", + "type": "NotImplementedError", + "code": 501 +} +``` + +### Error Handling in Code + +```python +async def safe_get_embeddings(text: str, model: str): + try: + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8080/v1/embeddings", + json={"input": text, "model": model} + ) + + if response.status_code == 200: + return response.json() + else: + error_data = response.json() + print(f"Error {response.status_code}: {error_data['message']}") + return None + + except httpx.RequestError as e: + print(f"Network error: {e}") + 
return None +``` + +## Performance Considerations + +### Batch Processing + +For better performance, batch multiple inputs together: + +```python +# Instead of making multiple single requests +texts = ["text1", "text2", "text3", "text4", "text5"] + +# Batch them together +response = await client.post( + "/v1/embeddings", + json={ + "input": texts, # Send all at once + "model": "your-model" + } +) +``` + +### Optimal Batch Sizes + +- **Small models**: 50-100 texts per batch +- **Large models**: 10-20 texts per batch +- **Very large models**: 1-5 texts per batch + +Monitor memory usage and adjust accordingly. + +### Caching + +Consider caching embeddings for frequently used texts: + +```python +import hashlib +from typing import Dict, List + +class EmbeddingCache: + def __init__(self): + self.cache: Dict[str, List[float]] = {} + + def get_cache_key(self, text: str, model: str) -> str: + return hashlib.md5(f"{text}:{model}".encode()).hexdigest() + + async def get_embedding(self, text: str, model: str) -> List[float]: + cache_key = self.get_cache_key(text, model) + + if cache_key in self.cache: + return self.cache[cache_key] + + # Get embedding from API + response = await self.fetch_embedding(text, model) + embedding = response["data"][0]["embedding"] + + # Cache the result + self.cache[cache_key] = embedding + return embedding +``` + +## RAG Integration Example + +Here's how to use embeddings in a Retrieval-Augmented Generation (RAG) system: + +```python +import numpy as np +from typing import List, Tuple + +class SimpleRAG: + def __init__(self, embedding_model: str): + self.embedding_model = embedding_model + self.documents: List[str] = [] + self.embeddings: List[List[float]] = [] + + async def add_document(self, text: str): + """Add a document to the knowledge base.""" + # Get embedding for the document + response = await self.get_embedding(text) + embedding = response["data"][0]["embedding"] + + self.documents.append(text) + self.embeddings.append(embedding) + + async def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]: + """Search for relevant documents.""" + # Get query embedding + response = await self.get_embedding(query) + query_embedding = np.array(response["data"][0]["embedding"]) + + # Calculate similarities + similarities = [] + for i, doc_embedding in enumerate(self.embeddings): + similarity = np.dot(query_embedding, doc_embedding) + similarities.append((self.documents[i], similarity)) + + # Return top-k most similar documents + similarities.sort(key=lambda x: x[1], reverse=True) + return similarities[:top_k] + + async def get_embedding(self, text: str): + async with httpx.AsyncClient() as client: + response = await client.post( + "http://localhost:8080/v1/embeddings", + json={ + "input": text, + "model": self.embedding_model + } + ) + return response.json() + +# Usage +rag = SimpleRAG("sentence-transformers/all-MiniLM-L6-v2") + +# Add documents +await rag.add_document("The capital of France is Paris.") +await rag.add_document("Python is a programming language.") +await rag.add_document("Machine learning is a subset of AI.") + +# Search +results = await rag.search("What is the capital of France?") +print(results[0][0]) # Should return the document about Paris +``` + +## Troubleshooting + +### Common Issues + +1. **"Model not support embeddings"** + - Ensure vLLM is started with `--task embed` flag + - Verify the model supports embedding generation + +2. 
**"Connection refused"** + - Check that vLLM server is running on the specified port + - Verify `INFERENCE_ENGINE_ENDPOINT` environment variable + +3. **Out of memory errors** + - Reduce batch size + - Use a smaller model + - Increase GPU memory allocation + +4. **Slow performance** + - Use GPU acceleration if available + - Implement request batching + - Consider model quantization + +### Debugging + +Enable debug logging for more detailed error information: + +```python +import logging + +# Enable debug logging +logging.basicConfig(level=logging.DEBUG) + +# Your embedding requests will now show detailed logs +``` + +## Next Steps + +- Explore different embedding models for your use case +- Implement caching for production deployments +- Set up monitoring and metrics collection +- Consider implementing custom preprocessing for your domain + +For more information, see the [AIBrix documentation](https://aibrix.readthedocs.io/) and [vLLM embedding guide](https://docs.vllm.ai/). \ No newline at end of file diff --git a/pkg/plugins/gateway/util.go b/pkg/plugins/gateway/util.go index 0e57ea540..1b9882f4e 100644 --- a/pkg/plugins/gateway/util.go +++ b/pkg/plugins/gateway/util.go @@ -69,6 +69,29 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user } model = completionObj.Model message = completionObj.Prompt + } else if requestPath == "/v1/embeddings" { + var embeddingReq struct { + Input interface{} `json:"input"` + Model string `json:"model"` + } + if err := json.Unmarshal(requestBody, &embeddingReq); err != nil { + klog.ErrorS(err, "error to unmarshal embeddings object", "requestID", requestID, "requestBody", string(requestBody)) + errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true") + return + } + model = embeddingReq.Model + // Convert input to string for message + switch v := embeddingReq.Input.(type) { + case string: + message = v + case []interface{}: + // Handle array inputs + if len(v) > 0 { + if str, ok := v[0].(string); ok { + message = str + } + } + } } else { errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true") return diff --git a/python/aibrix/aibrix/app.py b/python/aibrix/aibrix/app.py index a16dd1189..529f72e8d 100644 --- a/python/aibrix/aibrix/app.py +++ b/python/aibrix/aibrix/app.py @@ -29,6 +29,7 @@ from aibrix.openapi.model import ModelManager from aibrix.openapi.protocol import ( DownloadModelRequest, + EmbeddingRequest, ErrorResponse, ListModelRequest, LoadLoraAdapterRequest, @@ -188,6 +189,14 @@ async def readiness_check(): return JSONResponse(content={"status": "not ready"}, status_code=503) +@router.post("/v1/embeddings") +async def create_embeddings(request: EmbeddingRequest, raw_request: Request): + response = await inference_engine(raw_request).create_embeddings(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), status_code=response.code) + return JSONResponse(status_code=200, content=response.model_dump()) + + def build_app(args: argparse.Namespace): if args.enable_fastapi_docs: app = FastAPI(debug=False) diff --git a/python/aibrix/aibrix/openapi/engine/base.py b/python/aibrix/aibrix/openapi/engine/base.py index c3604cd89..3c5124908 100644 --- a/python/aibrix/aibrix/openapi/engine/base.py +++ b/python/aibrix/aibrix/openapi/engine/base.py @@ -20,6 +20,8 @@ from packaging.version import Version from aibrix.openapi.protocol 
import ( + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest, @@ -71,6 +73,16 @@ async def list_models(self) -> Union[ErrorResponse, str]: status_code=HTTPStatus.NOT_IMPLEMENTED, ) + async def create_embeddings( + self, request: EmbeddingRequest + ) -> Union[ErrorResponse, EmbeddingResponse]: + return self._create_error_response( + f"Inference engine {self.name} with version {self.version} " + "not support embeddings", + err_type="NotImplementedError", + status_code=HTTPStatus.NOT_IMPLEMENTED, + ) + def get_inference_engine(engine: str, version: str, endpoint: str) -> InferenceEngine: if engine.lower() == "vllm": diff --git a/python/aibrix/aibrix/openapi/engine/vllm.py b/python/aibrix/aibrix/openapi/engine/vllm.py index 2c9add91a..5903a7d24 100644 --- a/python/aibrix/aibrix/openapi/engine/vllm.py +++ b/python/aibrix/aibrix/openapi/engine/vllm.py @@ -21,6 +21,8 @@ from aibrix.logger import init_logger from aibrix.openapi.engine.base import InferenceEngine from aibrix.openapi.protocol import ( + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest, @@ -139,3 +141,29 @@ async def list_models(self) -> Union[ErrorResponse, str]: err_type="ServerError", status_code=HTTPStatus.INTERNAL_SERVER_ERROR, ) + + async def create_embeddings( + self, request: EmbeddingRequest + ) -> Union[ErrorResponse, EmbeddingResponse]: + embeddings_url = urljoin(self.endpoint, "/v1/embeddings") + + try: + response = await self.client.post( + embeddings_url, json=request.model_dump(), headers=self.headers + ) + except Exception as e: + logger.error(f"Failed to create embeddings: {e}") + return self._create_error_response( + "Failed to create embeddings", + err_type="ServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + if response.status_code != HTTPStatus.OK: + return self._create_error_response( + f"Failed to create embeddings: {response.text}", + err_type="ServerError", + status_code=HTTPStatus(value=response.status_code), + ) + + return EmbeddingResponse(**response.json()) diff --git a/python/aibrix/aibrix/openapi/protocol.py b/python/aibrix/aibrix/openapi/protocol.py index 47d34d9c4..0951f7f63 100644 --- a/python/aibrix/aibrix/openapi/protocol.py +++ b/python/aibrix/aibrix/openapi/protocol.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Optional +from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field @@ -66,3 +66,29 @@ class ListModelRequest(NoExtraBaseModel): class ListModelResponse(NoExtraBaseModel): object: str = "list" data: List[ModelStatusCard] = Field(default_factory=list) + + +class EmbeddingRequest(NoExtraBaseModel): + input: Union[str, List[str], List[int], List[List[int]]] + model: str + encoding_format: Optional[Literal["float", "base64"]] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + + +class EmbeddingData(NoExtraBaseModel): + object: Literal["embedding"] = "embedding" + embedding: Union[List[float], str] # float array or base64 string + index: int + + +class EmbeddingUsage(NoExtraBaseModel): + prompt_tokens: int + total_tokens: int + + +class EmbeddingResponse(NoExtraBaseModel): + object: Literal["list"] = "list" + data: List[EmbeddingData] + model: str + usage: EmbeddingUsage diff --git a/python/aibrix/tests/test_embedding_integration.py b/python/aibrix/tests/test_embedding_integration.py new file mode 100644 index 000000000..106e09c04 --- /dev/null +++ b/python/aibrix/tests/test_embedding_integration.py @@ -0,0 +1,406 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
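+# NOTE: the async tests below use @pytest.mark.asyncio and assume the
+# pytest-asyncio plugin is installed and enabled in the test environment.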
+ +from http import HTTPStatus +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi.testclient import TestClient + +from aibrix.app import build_app +from aibrix.openapi.engine.vllm import VLLMInferenceEngine +from aibrix.openapi.protocol import ( + EmbeddingData, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingUsage, + ErrorResponse, +) + + +class TestVLLMInferenceEngineEmbeddings: + def setup_method(self): + """Set up test fixtures.""" + self.engine = VLLMInferenceEngine( + name="vllm", + version="0.6.1", + endpoint="http://localhost:8000", + ) + + @pytest.mark.asyncio + async def test_create_embeddings_success(self): + """Test successful embeddings creation.""" + # Mock the VLLM response + mock_response = { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2, 0.3, 0.4], + "index": 0, + } + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 4, "total_tokens": 4}, + } + + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.OK + mock_http_response.json.return_value = mock_response + + with patch.object(self.engine.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input="Hello world", + model="text-embedding-ada-002", + ) + + result = await self.engine.create_embeddings(request) + + # Verify the result + assert isinstance(result, EmbeddingResponse) + assert result.object == "list" + assert len(result.data) == 1 + assert result.data[0].embedding == [0.1, 0.2, 0.3, 0.4] + assert result.model == "text-embedding-ada-002" + assert result.usage.prompt_tokens == 4 + + # Verify the HTTP call + mock_post.assert_called_once_with( + "http://localhost:8000/v1/embeddings", + json=request.model_dump(), + headers=self.engine.headers, + ) + + @pytest.mark.asyncio + async def test_create_embeddings_batch_input(self): + """Test embeddings creation with batch input.""" + mock_response = { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2, 0.3], + "index": 0, + }, + { + "object": "embedding", + "embedding": [0.4, 0.5, 0.6], + "index": 1, + }, + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 8, "total_tokens": 8}, + } + + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.OK + mock_http_response.json.return_value = mock_response + + with patch.object(self.engine.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input=["Hello", "World"], + model="text-embedding-ada-002", + ) + + result = await self.engine.create_embeddings(request) + + assert isinstance(result, EmbeddingResponse) + assert len(result.data) == 2 + assert result.data[0].index == 0 + assert result.data[1].index == 1 + + @pytest.mark.asyncio + async def test_create_embeddings_http_error(self): + """Test embeddings creation with HTTP error.""" + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.BAD_REQUEST + mock_http_response.text = "Invalid model" + + with patch.object(self.engine.client, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input="Hello world", + model="invalid-model", + ) + + result = await self.engine.create_embeddings(request) + + assert isinstance(result, ErrorResponse) + assert result.type == "ServerError" + assert "Failed to create embeddings: 
Invalid model" in result.message + assert result.code == HTTPStatus.BAD_REQUEST + + @pytest.mark.asyncio + async def test_create_embeddings_network_error(self): + """Test embeddings creation with network error.""" + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: + mock_post.side_effect = httpx.ConnectError("Connection failed") + + request = EmbeddingRequest( + input="Hello world", + model="text-embedding-ada-002", + ) + + result = await self.engine.create_embeddings(request) + + assert isinstance(result, ErrorResponse) + assert result.type == "ServerError" + assert result.message == "Failed to create embeddings" + assert result.code == HTTPStatus.INTERNAL_SERVER_ERROR + + @pytest.mark.asyncio + async def test_create_embeddings_with_api_key(self): + """Test embeddings creation with API key authentication.""" + engine_with_key = VLLMInferenceEngine( + name="vllm", + version="0.6.1", + endpoint="http://localhost:8000", + api_key="test-api-key", + ) + + mock_response = { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.1, 0.2], + "index": 0, + } + ], + "model": "text-embedding-ada-002", + "usage": {"prompt_tokens": 2, "total_tokens": 2}, + } + + mock_http_response = MagicMock() + mock_http_response.status_code = HTTPStatus.OK + mock_http_response.json.return_value = mock_response + + with patch.object( + engine_with_key.client, "post", new_callable=AsyncMock + ) as mock_post: + mock_post.return_value = mock_http_response + + request = EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + ) + + result = await engine_with_key.create_embeddings(request) + + assert isinstance(result, EmbeddingResponse) + # Verify that the client was created with the Authorization header + assert "Authorization" in engine_with_key.client.headers + assert engine_with_key.client.headers["Authorization"] == "Bearer test-api-key" + + +class TestEmbeddingsAPIEndpoint: + @pytest.fixture + def app(self): + """Create a test FastAPI app.""" + import argparse + + args = argparse.Namespace(enable_fastapi_docs=True) + return build_app(args) + + @pytest.fixture + def client(self, app): + """Create a test client.""" + return TestClient(app) + + def test_embeddings_endpoint_success(self, client): + """Test the /v1/embeddings endpoint with successful response.""" + # Mock the inference engine + mock_response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2, 0.3], index=0), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=3, total_tokens=3), + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_response) + mock_inference_engine.return_value = mock_engine + + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + "model": "text-embedding-ada-002", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["embedding"] == [0.1, 0.2, 0.3] + assert data["model"] == "text-embedding-ada-002" + + def test_embeddings_endpoint_error(self, client): + """Test the /v1/embeddings endpoint with error response.""" + mock_error = ErrorResponse( + message="Model not found", + type="NotFoundError", + code=404, + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_error) 
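+            # app.py maps ErrorResponse.code to the HTTP status, so this
+            # mocked error should surface to the client as a 404 below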
+ mock_inference_engine.return_value = mock_engine + + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + "model": "non-existent-model", + }, + ) + + assert response.status_code == 404 + data = response.json() + assert data["message"] == "Model not found" + assert data["type"] == "NotFoundError" + + def test_embeddings_endpoint_validation_error(self, client): + """Test the /v1/embeddings endpoint with validation error.""" + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + # Missing required 'model' field + }, + ) + + assert response.status_code == 422 # Validation error + data = response.json() + assert "detail" in data + + def test_embeddings_endpoint_different_input_types(self, client): + """Test the endpoint with different input types.""" + mock_response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2], index=0), + EmbeddingData(embedding=[0.3, 0.4], index=1), + ], + model="test-model", + usage=EmbeddingUsage(prompt_tokens=4, total_tokens=4), + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_response) + mock_inference_engine.return_value = mock_engine + + # Test with string array + response = client.post( + "/v1/embeddings", + json={ + "input": ["Hello", "World"], + "model": "test-model", + }, + ) + assert response.status_code == 200 + + # Test with token array + response = client.post( + "/v1/embeddings", + json={ + "input": [101, 102, 103], + "model": "test-model", + }, + ) + assert response.status_code == 200 + + # Test with nested token array + response = client.post( + "/v1/embeddings", + json={ + "input": [[101, 102], [103, 104]], + "model": "test-model", + }, + ) + assert response.status_code == 200 + + def test_embeddings_endpoint_optional_parameters(self, client): + """Test the endpoint with optional parameters.""" + mock_response = EmbeddingResponse( + data=[ + EmbeddingData(embedding="base64encodedstring", index=0), + ], + model="test-model", + usage=EmbeddingUsage(prompt_tokens=3, total_tokens=3), + ) + + with patch("aibrix.app.inference_engine") as mock_inference_engine: + mock_engine = MagicMock() + mock_engine.create_embeddings = AsyncMock(return_value=mock_response) + mock_inference_engine.return_value = mock_engine + + response = client.post( + "/v1/embeddings", + json={ + "input": "Hello world", + "model": "test-model", + "encoding_format": "base64", + "dimensions": 256, + "user": "test-user", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["data"][0]["embedding"] == "base64encodedstring" + + # Verify the request was passed correctly + call_args = mock_engine.create_embeddings.call_args[0][0] + assert call_args.encoding_format == "base64" + assert call_args.dimensions == 256 + assert call_args.user == "test-user" + + +class TestBaseInferenceEngineEmbeddings: + def test_base_engine_not_implemented(self): + """Test that base inference engine returns NotImplementedError.""" + from aibrix.openapi.engine.base import InferenceEngine + + engine = InferenceEngine( + name="test", + version="1.0", + endpoint="http://localhost:8000", + ) + + request = EmbeddingRequest( + input="test", + model="test-model", + ) + + import asyncio + + result = asyncio.run(engine.create_embeddings(request)) + + assert isinstance(result, ErrorResponse) + assert result.type == "NotImplementedError" + assert result.code == HTTPStatus.NOT_IMPLEMENTED + assert "not support 
embeddings" in result.message \ No newline at end of file diff --git a/python/aibrix/tests/test_embedding_protocol.py b/python/aibrix/tests/test_embedding_protocol.py new file mode 100644 index 000000000..b9a96f587 --- /dev/null +++ b/python/aibrix/tests/test_embedding_protocol.py @@ -0,0 +1,293 @@ +# Copyright 2024 The Aibrix Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pydantic import ValidationError + +from aibrix.openapi.protocol import ( + EmbeddingData, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingUsage, +) + + +class TestEmbeddingRequest: + def test_valid_string_input(self): + """Test EmbeddingRequest with valid string input.""" + request = EmbeddingRequest( + input="Hello, world!", + model="text-embedding-ada-002", + ) + assert request.input == "Hello, world!" + assert request.model == "text-embedding-ada-002" + assert request.encoding_format == "float" # default value + assert request.dimensions is None + assert request.user is None + + def test_valid_string_array_input(self): + """Test EmbeddingRequest with string array input.""" + request = EmbeddingRequest( + input=["Hello", "World"], + model="text-embedding-ada-002", + ) + assert request.input == ["Hello", "World"] + assert request.model == "text-embedding-ada-002" + + def test_valid_token_array_input(self): + """Test EmbeddingRequest with token array input.""" + request = EmbeddingRequest( + input=[101, 102, 103], + model="text-embedding-ada-002", + ) + assert request.input == [101, 102, 103] + + def test_valid_nested_token_array_input(self): + """Test EmbeddingRequest with nested token array input.""" + request = EmbeddingRequest( + input=[[101, 102], [103, 104]], + model="text-embedding-ada-002", + ) + assert request.input == [[101, 102], [103, 104]] + + def test_optional_parameters(self): + """Test EmbeddingRequest with optional parameters.""" + request = EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + encoding_format="base64", + dimensions=512, + user="user123", + ) + assert request.encoding_format == "base64" + assert request.dimensions == 512 + assert request.user == "user123" + + def test_invalid_encoding_format(self): + """Test EmbeddingRequest with invalid encoding format.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + encoding_format="invalid", + ) + assert "Input should be 'float' or 'base64'" in str(exc_info.value) + + def test_missing_required_fields(self): + """Test EmbeddingRequest with missing required fields.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingRequest(input="test") + assert "Field required" in str(exc_info.value) + + def test_extra_fields_not_allowed(self): + """Test that extra fields are not allowed.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + extra_field="not_allowed", + ) + assert "Extra inputs are not permitted" in str(exc_info.value) + + def 
test_model_dump(self): + """Test serialization of EmbeddingRequest.""" + request = EmbeddingRequest( + input="test", + model="text-embedding-ada-002", + dimensions=256, + ) + data = request.model_dump() + assert data["input"] == "test" + assert data["model"] == "text-embedding-ada-002" + assert data["encoding_format"] == "float" + assert data["dimensions"] == 256 + assert data["user"] is None + + +class TestEmbeddingData: + def test_valid_float_embedding(self): + """Test EmbeddingData with float array embedding.""" + data = EmbeddingData( + embedding=[0.1, 0.2, 0.3, 0.4], + index=0, + ) + assert data.object == "embedding" + assert data.embedding == [0.1, 0.2, 0.3, 0.4] + assert data.index == 0 + + def test_valid_base64_embedding(self): + """Test EmbeddingData with base64 string embedding.""" + data = EmbeddingData( + embedding="base64encodedstring", + index=1, + ) + assert data.object == "embedding" + assert data.embedding == "base64encodedstring" + assert data.index == 1 + + def test_object_literal_fixed(self): + """Test that object field is always 'embedding'.""" + data = EmbeddingData( + embedding=[0.1, 0.2], + index=0, + ) + assert data.object == "embedding" + # Cannot override the literal value + + def test_missing_required_fields(self): + """Test EmbeddingData with missing required fields.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingData(embedding=[0.1, 0.2]) + assert "Field required" in str(exc_info.value) + + +class TestEmbeddingUsage: + def test_valid_usage(self): + """Test EmbeddingUsage with valid token counts.""" + usage = EmbeddingUsage( + prompt_tokens=10, + total_tokens=10, + ) + assert usage.prompt_tokens == 10 + assert usage.total_tokens == 10 + + def test_zero_tokens(self): + """Test EmbeddingUsage with zero tokens.""" + usage = EmbeddingUsage( + prompt_tokens=0, + total_tokens=0, + ) + assert usage.prompt_tokens == 0 + assert usage.total_tokens == 0 + + def test_missing_fields(self): + """Test EmbeddingUsage with missing fields.""" + with pytest.raises(ValidationError) as exc_info: + EmbeddingUsage(prompt_tokens=10) + assert "Field required" in str(exc_info.value) + + +class TestEmbeddingResponse: + def test_valid_response(self): + """Test EmbeddingResponse with valid data.""" + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2, 0.3], index=0), + EmbeddingData(embedding=[0.4, 0.5, 0.6], index=1), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=6, total_tokens=6), + ) + assert response.object == "list" + assert len(response.data) == 2 + assert response.model == "text-embedding-ada-002" + assert response.usage.prompt_tokens == 6 + assert response.usage.total_tokens == 6 + + def test_empty_data_list(self): + """Test EmbeddingResponse with empty data list.""" + response = EmbeddingResponse( + data=[], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=0, total_tokens=0), + ) + assert len(response.data) == 0 + + def test_model_dump_json(self): + """Test JSON serialization of EmbeddingResponse.""" + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2], index=0), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=3, total_tokens=3), + ) + data = response.model_dump() + assert data["object"] == "list" + assert len(data["data"]) == 1 + assert data["data"][0]["object"] == "embedding" + assert data["data"][0]["embedding"] == [0.1, 0.2] + assert data["data"][0]["index"] == 0 + assert data["model"] == "text-embedding-ada-002" + assert 
data["usage"]["prompt_tokens"] == 3 + assert data["usage"]["total_tokens"] == 3 + + def test_base64_response(self): + """Test EmbeddingResponse with base64 encoded embeddings.""" + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding="base64string1", index=0), + EmbeddingData(embedding="base64string2", index=1), + ], + model="text-embedding-ada-002", + usage=EmbeddingUsage(prompt_tokens=10, total_tokens=10), + ) + assert response.data[0].embedding == "base64string1" + assert response.data[1].embedding == "base64string2" + + +class TestEmbeddingProtocolIntegration: + def test_request_response_cycle(self): + """Test a complete request-response cycle.""" + # Create a request + request = EmbeddingRequest( + input=["Hello", "World"], + model="text-embedding-ada-002", + encoding_format="float", + ) + + # Simulate processing and create response + response = EmbeddingResponse( + data=[ + EmbeddingData(embedding=[0.1, 0.2, 0.3], index=0), + EmbeddingData(embedding=[0.4, 0.5, 0.6], index=1), + ], + model=request.model, + usage=EmbeddingUsage(prompt_tokens=4, total_tokens=4), + ) + + # Verify the response matches the request + assert response.model == request.model + assert len(response.data) == len(request.input) + + def test_mixed_input_types(self): + """Test that different input types are properly validated.""" + # Valid cases + valid_inputs = [ + "single string", + ["multiple", "strings"], + [1, 2, 3, 4], + [[1, 2], [3, 4]], + ] + + for input_val in valid_inputs: + request = EmbeddingRequest( + input=input_val, + model="test-model", + ) + assert request.input == input_val + + # Invalid cases + invalid_inputs = [ + {"dict": "not allowed"}, + 12.34, # float not allowed + True, # boolean not allowed + ] + + for input_val in invalid_inputs: + with pytest.raises(ValidationError): + EmbeddingRequest( + input=input_val, + model="test-model", + ) \ No newline at end of file From 2ed82b235f9b9f534d5c9941ba1027abf8182199 Mon Sep 17 00:00:00 2001 From: dittops Date: Mon, 23 Jun 2025 06:16:01 +0000 Subject: [PATCH 2/6] fix: adding PathPrefix in httproute Signed-off-by: dittops --- config/gateway/gateway-plugin/gateway-plugin.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/config/gateway/gateway-plugin/gateway-plugin.yaml b/config/gateway/gateway-plugin/gateway-plugin.yaml index d6bd9e0ad..6a06d0a69 100644 --- a/config/gateway/gateway-plugin/gateway-plugin.yaml +++ b/config/gateway/gateway-plugin/gateway-plugin.yaml @@ -165,6 +165,9 @@ spec: - path: type: PathPrefix value: /v1/completions + - path: + type: PathPrefix + value: /v1/embeddings backendRefs: - name: aibrix-gateway-plugins port: 50052 From 3144d7efb8e31d42df828cffb68d9a23148eebc7 Mon Sep 17 00:00:00 2001 From: dittops Date: Mon, 23 Jun 2025 06:50:01 +0000 Subject: [PATCH 3/6] fix: address PR review comments - Add try-except for JSON parsing errors in vllm.py - Use specific httpx.RequestError instead of generic Exception - Improve token array handling in util.go for numeric inputs Signed-off-by: dittops --- pkg/plugins/gateway/util.go | 16 ++++++++++++++-- python/aibrix/aibrix/openapi/engine/vllm.py | 12 ++++++++++-- .../tests/test_embedding_integration.py | 19 ++++++++++++++----- .../aibrix/tests/test_embedding_protocol.py | 2 +- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/pkg/plugins/gateway/util.go b/pkg/plugins/gateway/util.go index 1b9882f4e..bc2d0831a 100644 --- a/pkg/plugins/gateway/util.go +++ b/pkg/plugins/gateway/util.go @@ -18,6 +18,7 @@ package gateway import ( "encoding/json" + "fmt" "strings" 
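+	// fmt is used below to summarize token-array inputs for validateRequestBody's message field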
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -87,8 +88,19 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user case []interface{}: // Handle array inputs if len(v) > 0 { - if str, ok := v[0].(string); ok { - message = str + switch elem := v[0].(type) { + case string: + message = elem + case float64: + // Handle token ID (number) + message = fmt.Sprintf("Token array input (first token: %v)", elem) + case []interface{}: + // Handle nested array (number[][]) + if len(elem) > 0 { + if token, ok := elem[0].(float64); ok { + message = fmt.Sprintf("Nested token array input (first token: %v)", token) + } + } } } } diff --git a/python/aibrix/aibrix/openapi/engine/vllm.py b/python/aibrix/aibrix/openapi/engine/vllm.py index 5903a7d24..0eddc2cb6 100644 --- a/python/aibrix/aibrix/openapi/engine/vllm.py +++ b/python/aibrix/aibrix/openapi/engine/vllm.py @@ -151,7 +151,7 @@ async def create_embeddings( response = await self.client.post( embeddings_url, json=request.model_dump(), headers=self.headers ) - except Exception as e: + except httpx.RequestError as e: logger.error(f"Failed to create embeddings: {e}") return self._create_error_response( "Failed to create embeddings", @@ -166,4 +166,12 @@ async def create_embeddings( status_code=HTTPStatus(value=response.status_code), ) - return EmbeddingResponse(**response.json()) + try: + return EmbeddingResponse(**response.json()) + except Exception as e: + logger.error(f"Failed to parse embedding response: {e}") + return self._create_error_response( + "Invalid response from inference engine", + err_type="ServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) diff --git a/python/aibrix/tests/test_embedding_integration.py b/python/aibrix/tests/test_embedding_integration.py index 106e09c04..3069a307f 100644 --- a/python/aibrix/tests/test_embedding_integration.py +++ b/python/aibrix/tests/test_embedding_integration.py @@ -60,7 +60,9 @@ async def test_create_embeddings_success(self): mock_http_response.status_code = HTTPStatus.OK mock_http_response.json.return_value = mock_response - with patch.object(self.engine.client, "post", new_callable=AsyncMock) as mock_post: + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: mock_post.return_value = mock_http_response request = EmbeddingRequest( @@ -110,7 +112,9 @@ async def test_create_embeddings_batch_input(self): mock_http_response.status_code = HTTPStatus.OK mock_http_response.json.return_value = mock_response - with patch.object(self.engine.client, "post", new_callable=AsyncMock) as mock_post: + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: mock_post.return_value = mock_http_response request = EmbeddingRequest( @@ -132,7 +136,9 @@ async def test_create_embeddings_http_error(self): mock_http_response.status_code = HTTPStatus.BAD_REQUEST mock_http_response.text = "Invalid model" - with patch.object(self.engine.client, "post", new_callable=AsyncMock) as mock_post: + with patch.object( + self.engine.client, "post", new_callable=AsyncMock + ) as mock_post: mock_post.return_value = mock_http_response request = EmbeddingRequest( @@ -209,7 +215,10 @@ async def test_create_embeddings_with_api_key(self): assert isinstance(result, EmbeddingResponse) # Verify that the client was created with the Authorization header assert "Authorization" in engine_with_key.client.headers - assert engine_with_key.client.headers["Authorization"] == "Bearer test-api-key" + assert ( + 
engine_with_key.client.headers["Authorization"] + == "Bearer test-api-key" + ) class TestEmbeddingsAPIEndpoint: @@ -403,4 +412,4 @@ def test_base_engine_not_implemented(self): assert isinstance(result, ErrorResponse) assert result.type == "NotImplementedError" assert result.code == HTTPStatus.NOT_IMPLEMENTED - assert "not support embeddings" in result.message \ No newline at end of file + assert "not support embeddings" in result.message diff --git a/python/aibrix/tests/test_embedding_protocol.py b/python/aibrix/tests/test_embedding_protocol.py index b9a96f587..e1a21ab75 100644 --- a/python/aibrix/tests/test_embedding_protocol.py +++ b/python/aibrix/tests/test_embedding_protocol.py @@ -290,4 +290,4 @@ def test_mixed_input_types(self): EmbeddingRequest( input=input_val, model="test-model", - ) \ No newline at end of file + ) From e72d978df64ebc6f7c621b7b2e48822f7ced508b Mon Sep 17 00:00:00 2001 From: dittops Date: Mon, 23 Jun 2025 07:17:05 +0000 Subject: [PATCH 4/6] fix: formatting issues in test files Signed-off-by: dittops --- python/aibrix/tests/test_embedding_integration.py | 1 + python/aibrix/tests/test_embedding_protocol.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/aibrix/tests/test_embedding_integration.py b/python/aibrix/tests/test_embedding_integration.py index 3069a307f..d4dcd23c6 100644 --- a/python/aibrix/tests/test_embedding_integration.py +++ b/python/aibrix/tests/test_embedding_integration.py @@ -413,3 +413,4 @@ def test_base_engine_not_implemented(self): assert result.type == "NotImplementedError" assert result.code == HTTPStatus.NOT_IMPLEMENTED assert "not support embeddings" in result.message + diff --git a/python/aibrix/tests/test_embedding_protocol.py b/python/aibrix/tests/test_embedding_protocol.py index e1a21ab75..e1ee47b00 100644 --- a/python/aibrix/tests/test_embedding_protocol.py +++ b/python/aibrix/tests/test_embedding_protocol.py @@ -291,3 +291,4 @@ def test_mixed_input_types(self): input=input_val, model="test-model", ) + From a4ff6a6f747d1165c605a3a0d1e3609c6af37ac3 Mon Sep 17 00:00:00 2001 From: dittops Date: Mon, 23 Jun 2025 07:19:41 +0000 Subject: [PATCH 5/6] remove md file Signed-off-by: dittops --- EMBEDDINGS_GUIDE.md | 476 -------------------------------------------- 1 file changed, 476 deletions(-) delete mode 100644 EMBEDDINGS_GUIDE.md diff --git a/EMBEDDINGS_GUIDE.md b/EMBEDDINGS_GUIDE.md deleted file mode 100644 index 155863e37..000000000 --- a/EMBEDDINGS_GUIDE.md +++ /dev/null @@ -1,476 +0,0 @@ -# AIBrix Embeddings API Guide - -This guide provides comprehensive documentation for using the `/v1/embeddings` endpoint in AIBrix. - -## Overview - -The `/v1/embeddings` endpoint enables you to generate vector embeddings from text inputs using various embedding models. This endpoint follows the OpenAI embeddings API specification, ensuring compatibility with existing tools and libraries. - -## Table of Contents - -- [Prerequisites](#prerequisites) -- [API Reference](#api-reference) -- [Usage Examples](#usage-examples) -- [Configuration](#configuration) -- [Error Handling](#error-handling) -- [Performance Considerations](#performance-considerations) - -## Prerequisites - -### vLLM Configuration - -To use embeddings with AIBrix, you need: - -1. **vLLM version 0.4.3 or higher** -2. 
**Models loaded with embedding task support** - -```bash -# Start vLLM with embedding support -python -m vllm.entrypoints.openai.api_server \ - --model sentence-transformers/all-MiniLM-L6-v2 \ - --task embed \ - --port 8000 -``` - -### Supported Models - -Common embedding models that work with vLLM: -- `sentence-transformers/all-MiniLM-L6-v2` -- `sentence-transformers/all-mpnet-base-v2` -- `intfloat/e5-large-v2` -- `BAAI/bge-large-en-v1.5` - -## API Reference - -### Request Format - -```http -POST /v1/embeddings -Content-Type: application/json - -{ - "input": "string | string[] | number[] | number[][]", - "model": "string", - "encoding_format": "float | base64", // optional, default: "float" - "dimensions": "number", // optional - "user": "string" // optional -} -``` - -### Response Format - -```json -{ - "object": "list", - "data": [ - { - "object": "embedding", - "embedding": [0.1, 0.2, 0.3, ...], - "index": 0 - } - ], - "model": "sentence-transformers/all-MiniLM-L6-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } -} -``` - -### Parameters - -| Parameter | Type | Required | Description | -|-----------|------|----------|-------------| -| `input` | `string \| string[] \| number[] \| number[][]` | Yes | Text input(s) to generate embeddings for | -| `model` | `string` | Yes | ID of the model to use | -| `encoding_format` | `string` | No | Format to return embeddings in: `"float"` or `"base64"` | -| `dimensions` | `integer` | No | Number of dimensions for the embedding (model-dependent) | -| `user` | `string` | No | Unique identifier for the user | - -## Usage Examples - -### Single Text Input - -```python -import httpx - -async def get_embedding(): - async with httpx.AsyncClient() as client: - response = await client.post( - "http://localhost:8080/v1/embeddings", - json={ - "input": "The quick brown fox jumps over the lazy dog", - "model": "sentence-transformers/all-MiniLM-L6-v2" - } - ) - return response.json() - -# Example response: -# { -# "object": "list", -# "data": [ -# { -# "object": "embedding", -# "embedding": [0.012, -0.045, 0.123, ...], -# "index": 0 -# } -# ], -# "model": "sentence-transformers/all-MiniLM-L6-v2", -# "usage": { -# "prompt_tokens": 9, -# "total_tokens": 9 -# } -# } -``` - -### Batch Text Inputs - -```python -async def get_batch_embeddings(): - async with httpx.AsyncClient() as client: - response = await client.post( - "http://localhost:8080/v1/embeddings", - json={ - "input": [ - "Hello world", - "How are you?", - "Goodbye!" 
- ], - "model": "sentence-transformers/all-MiniLM-L6-v2" - } - ) - return response.json() - -# Example response: -# { -# "object": "list", -# "data": [ -# { -# "object": "embedding", -# "embedding": [0.012, -0.045, ...], -# "index": 0 -# }, -# { -# "object": "embedding", -# "embedding": [0.034, -0.067, ...], -# "index": 1 -# }, -# { -# "object": "embedding", -# "embedding": [0.056, -0.089, ...], -# "index": 2 -# } -# ], -# "model": "sentence-transformers/all-MiniLM-L6-v2", -# "usage": { -# "prompt_tokens": 6, -# "total_tokens": 6 -# } -# } -``` - -### Base64 Encoding - -```python -async def get_base64_embeddings(): - async with httpx.AsyncClient() as client: - response = await client.post( - "http://localhost:8080/v1/embeddings", - json={ - "input": "Convert this to base64", - "model": "sentence-transformers/all-MiniLM-L6-v2", - "encoding_format": "base64" - } - ) - return response.json() - -# The embedding will be returned as a base64-encoded string -``` - -### Token Array Input - -```python -async def get_token_embeddings(): - async with httpx.AsyncClient() as client: - response = await client.post( - "http://localhost:8080/v1/embeddings", - json={ - "input": [101, 7592, 2088, 102], # Tokenized input - "model": "sentence-transformers/all-MiniLM-L6-v2" - } - ) - return response.json() -``` - -### Using with OpenAI Client - -```python -from openai import OpenAI - -# Configure client to use AIBrix -client = OpenAI( - base_url="http://localhost:8080/v1", - api_key="not-needed" # AIBrix doesn't require API key by default -) - -# Generate embeddings -response = client.embeddings.create( - input="Your text here", - model="sentence-transformers/all-MiniLM-L6-v2" -) - -embedding = response.data[0].embedding -print(f"Embedding dimension: {len(embedding)}") -``` - -## Configuration - -### Environment Variables - -Configure AIBrix for embeddings support: - -```bash -# Set the inference engine -export INFERENCE_ENGINE=vllm -export INFERENCE_ENGINE_VERSION=0.6.1 -export INFERENCE_ENGINE_ENDPOINT=http://localhost:8000 - -# Optional: Set routing strategy -export ROUTING_ALGORITHM=random -``` - -### Model Configuration - -Ensure your embedding model is properly configured in vLLM: - -```bash -# Example vLLM startup with specific parameters -python -m vllm.entrypoints.openai.api_server \ - --model sentence-transformers/all-MiniLM-L6-v2 \ - --task embed \ - --port 8000 \ - --max-model-len 512 \ - --trust-remote-code -``` - -## Error Handling - -### Common Error Responses - -#### Model Not Found (404) -```json -{ - "object": "error", - "message": "Model 'non-existent-model' not found", - "type": "NotFoundError", - "code": 404 -} -``` - -#### Invalid Input Format (400) -```json -{ - "object": "error", - "message": "Invalid input format", - "type": "BadRequestError", - "code": 400 -} -``` - -#### Model Doesn't Support Embeddings (501) -```json -{ - "object": "error", - "message": "Inference engine vllm with version 0.6.1 not support embeddings", - "type": "NotImplementedError", - "code": 501 -} -``` - -### Error Handling in Code - -```python -async def safe_get_embeddings(text: str, model: str): - try: - async with httpx.AsyncClient() as client: - response = await client.post( - "http://localhost:8080/v1/embeddings", - json={"input": text, "model": model} - ) - - if response.status_code == 200: - return response.json() - else: - error_data = response.json() - print(f"Error {response.status_code}: {error_data['message']}") - return None - - except httpx.RequestError as e: - print(f"Network error: {e}") - 
return None -``` - -## Performance Considerations - -### Batch Processing - -For better performance, batch multiple inputs together: - -```python -# Instead of making multiple single requests -texts = ["text1", "text2", "text3", "text4", "text5"] - -# Batch them together -response = await client.post( - "/v1/embeddings", - json={ - "input": texts, # Send all at once - "model": "your-model" - } -) -``` - -### Optimal Batch Sizes - -- **Small models**: 50-100 texts per batch -- **Large models**: 10-20 texts per batch -- **Very large models**: 1-5 texts per batch - -Monitor memory usage and adjust accordingly. - -### Caching - -Consider caching embeddings for frequently used texts: - -```python -import hashlib -from typing import Dict, List - -class EmbeddingCache: - def __init__(self): - self.cache: Dict[str, List[float]] = {} - - def get_cache_key(self, text: str, model: str) -> str: - return hashlib.md5(f"{text}:{model}".encode()).hexdigest() - - async def get_embedding(self, text: str, model: str) -> List[float]: - cache_key = self.get_cache_key(text, model) - - if cache_key in self.cache: - return self.cache[cache_key] - - # Get embedding from API - response = await self.fetch_embedding(text, model) - embedding = response["data"][0]["embedding"] - - # Cache the result - self.cache[cache_key] = embedding - return embedding -``` - -## RAG Integration Example - -Here's how to use embeddings in a Retrieval-Augmented Generation (RAG) system: - -```python -import numpy as np -from typing import List, Tuple - -class SimpleRAG: - def __init__(self, embedding_model: str): - self.embedding_model = embedding_model - self.documents: List[str] = [] - self.embeddings: List[List[float]] = [] - - async def add_document(self, text: str): - """Add a document to the knowledge base.""" - # Get embedding for the document - response = await self.get_embedding(text) - embedding = response["data"][0]["embedding"] - - self.documents.append(text) - self.embeddings.append(embedding) - - async def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]: - """Search for relevant documents.""" - # Get query embedding - response = await self.get_embedding(query) - query_embedding = np.array(response["data"][0]["embedding"]) - - # Calculate similarities - similarities = [] - for i, doc_embedding in enumerate(self.embeddings): - similarity = np.dot(query_embedding, doc_embedding) - similarities.append((self.documents[i], similarity)) - - # Return top-k most similar documents - similarities.sort(key=lambda x: x[1], reverse=True) - return similarities[:top_k] - - async def get_embedding(self, text: str): - async with httpx.AsyncClient() as client: - response = await client.post( - "http://localhost:8080/v1/embeddings", - json={ - "input": text, - "model": self.embedding_model - } - ) - return response.json() - -# Usage -rag = SimpleRAG("sentence-transformers/all-MiniLM-L6-v2") - -# Add documents -await rag.add_document("The capital of France is Paris.") -await rag.add_document("Python is a programming language.") -await rag.add_document("Machine learning is a subset of AI.") - -# Search -results = await rag.search("What is the capital of France?") -print(results[0][0]) # Should return the document about Paris -``` - -## Troubleshooting - -### Common Issues - -1. **"Model not support embeddings"** - - Ensure vLLM is started with `--task embed` flag - - Verify the model supports embedding generation - -2. 
**"Connection refused"** - - Check that vLLM server is running on the specified port - - Verify `INFERENCE_ENGINE_ENDPOINT` environment variable - -3. **Out of memory errors** - - Reduce batch size - - Use a smaller model - - Increase GPU memory allocation - -4. **Slow performance** - - Use GPU acceleration if available - - Implement request batching - - Consider model quantization - -### Debugging - -Enable debug logging for more detailed error information: - -```python -import logging - -# Enable debug logging -logging.basicConfig(level=logging.DEBUG) - -# Your embedding requests will now show detailed logs -``` - -## Next Steps - -- Explore different embedding models for your use case -- Implement caching for production deployments -- Set up monitoring and metrics collection -- Consider implementing custom preprocessing for your domain - -For more information, see the [AIBrix documentation](https://aibrix.readthedocs.io/) and [vLLM embedding guide](https://docs.vllm.ai/). \ No newline at end of file From 1004019e93a9da9dbf1a2b61156384f544832b71 Mon Sep 17 00:00:00 2001 From: dittops Date: Wed, 25 Jun 2025 06:14:42 +0000 Subject: [PATCH 6/6] fix: add path for httproute --- pkg/controller/modelrouter/modelrouter_controller.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pkg/controller/modelrouter/modelrouter_controller.go b/pkg/controller/modelrouter/modelrouter_controller.go index 6ead4a856..7855343a4 100644 --- a/pkg/controller/modelrouter/modelrouter_controller.go +++ b/pkg/controller/modelrouter/modelrouter_controller.go @@ -228,6 +228,15 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string modelHeaderMatch, }, }, + { + Path: &gatewayv1.HTTPPathMatch{ + Type: ptr.To(gatewayv1.PathMatchPathPrefix), + Value: ptr.To("/v1/embeddings"), + }, + Headers: []gatewayv1.HTTPHeaderMatch{ + modelHeaderMatch, + }, + }, }, BackendRefs: []gatewayv1.HTTPBackendRef{ {