3 changes: 3 additions & 0 deletions config/gateway/gateway-plugin/gateway-plugin.yaml
@@ -165,6 +165,9 @@ spec:
- path:
type: PathPrefix
value: /v1/completions
- path:
type: PathPrefix
value: /v1/embeddings
backendRefs:
- name: aibrix-gateway-plugins
port: 50052
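With this route in place, the gateway listener forwards embedding traffic to the same plugin backend that already handles chat and completion calls. A minimal sketch of a client request, assuming a hypothetical locally reachable gateway address and placeholder model name (neither is part of this change):

```python
# Minimal sketch of a request through the new /v1/embeddings route.
# The gateway address and model name below are placeholders.
import requests

resp = requests.post(
    "http://localhost:8888/v1/embeddings",  # hypothetical gateway endpoint
    json={"model": "my-embedding-model", "input": "hello world"},
    timeout=30,
)
resp.raise_for_status()
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimensionality
```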
9 changes: 9 additions & 0 deletions pkg/controller/modelrouter/modelrouter_controller.go
@@ -228,6 +228,15 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string
modelHeaderMatch,
},
},
{
Path: &gatewayv1.HTTPPathMatch{
Type: ptr.To(gatewayv1.PathMatchPathPrefix),
Value: ptr.To("/v1/embeddings"),
},
Headers: []gatewayv1.HTTPHeaderMatch{
modelHeaderMatch,
},
},
},
BackendRefs: []gatewayv1.HTTPBackendRef{
{
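The controller-side change mirrors the static gateway config: each per-model HTTPRoute gains a rule match for the embeddings path, gated on the same model header as the chat and completion matches. A hedged sketch of the Gateway API structure this renders to, shown as the equivalent JSON; the header name is whatever modelHeaderMatch carries (the "model" header in AIBrix), and sibling fields are omitted:

```python
# Hedged sketch of the rendered gateway.networking.k8s.io/v1 match;
# the header match details are assumptions based on modelHeaderMatch.
embeddings_rule_match = {
    "path": {"type": "PathPrefix", "value": "/v1/embeddings"},
    "headers": [{"type": "Exact", "name": "model", "value": "<model-name>"}],
}
```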
35 changes: 35 additions & 0 deletions pkg/plugins/gateway/util.go
@@ -18,6 +18,7 @@ package gateway

import (
"encoding/json"
"fmt"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -69,6 +70,40 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
}
model = completionObj.Model
message = completionObj.Prompt
} else if requestPath == "/v1/embeddings" {
var embeddingReq struct {
Input interface{} `json:"input"`
Model string `json:"model"`
}
if err := json.Unmarshal(requestBody, &embeddingReq); err != nil {
klog.ErrorS(err, "error to unmarshal embeddings object", "requestID", requestID, "requestBody", string(requestBody))
errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true")
return
}
model = embeddingReq.Model
// Convert input to string for message
switch v := embeddingReq.Input.(type) {
case string:
message = v
case []interface{}:
// Handle array inputs
if len(v) > 0 {
switch elem := v[0].(type) {
case string:
message = elem
case float64:
// Handle token ID (number)
message = fmt.Sprintf("Token array input (first token: %v)", elem)
case []interface{}:
// Handle nested array (number[][])
if len(elem) > 0 {
if token, ok := elem[0].(float64); ok {
message = fmt.Sprintf("Nested token array input (first token: %v)", token)
}
}
}
}
}
} else {
errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true")
return
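The validator accepts the same input shapes as the OpenAI embeddings API: a single string, a list of strings, a list of token IDs, or a nested list of token IDs (JSON numbers unmarshal to float64 in Go, which is why the numeric cases match on float64). A sketch of request bodies that exercise each branch of the switch above; the model name is a placeholder:

```python
# Request bodies reaching each branch of the Go input switch.
bodies = [
    {"model": "m", "input": "a single string"},           # case string
    {"model": "m", "input": ["first string", "second"]},  # array, string element
    {"model": "m", "input": [101, 2009, 102]},            # array, float64 element (token IDs)
    {"model": "m", "input": [[101, 2009], [101, 2129]]},  # nested array (number[][])
]
```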
9 changes: 9 additions & 0 deletions python/aibrix/aibrix/app.py
@@ -29,6 +29,7 @@
from aibrix.openapi.model import ModelManager
from aibrix.openapi.protocol import (
DownloadModelRequest,
EmbeddingRequest,
ErrorResponse,
ListModelRequest,
LoadLoraAdapterRequest,
@@ -188,6 +189,14 @@ async def readiness_check():
return JSONResponse(content={"status": "not ready"}, status_code=503)


@router.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingRequest, raw_request: Request):
response = await inference_engine(raw_request).create_embeddings(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), status_code=response.code)
return JSONResponse(status_code=200, content=response.model_dump())


def build_app(args: argparse.Namespace):
if args.enable_fastapi_docs:
app = FastAPI(debug=False)
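The route delegates to the configured inference engine and surfaces either an EmbeddingResponse or the engine's own ErrorResponse with its status code. A hedged sketch of calling the runtime endpoint directly; the port is a placeholder, and an engine without embedding support would answer 501 via the base-class fallback in the next file:

```python
# Hypothetical direct call to the AIBrix runtime's new endpoint.
import requests

resp = requests.post(
    "http://localhost:8080/v1/embeddings",  # placeholder runtime address
    json={"model": "my-embedding-model", "input": ["hello", "world"]},
    timeout=30,
)
print(resp.status_code, resp.json())
```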
12 changes: 12 additions & 0 deletions python/aibrix/aibrix/openapi/engine/base.py
@@ -20,6 +20,8 @@
from packaging.version import Version

from aibrix.openapi.protocol import (
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest,
@@ -71,6 +73,16 @@ async def list_models(self) -> Union[ErrorResponse, str]:
status_code=HTTPStatus.NOT_IMPLEMENTED,
)

async def create_embeddings(
self, request: EmbeddingRequest
) -> Union[ErrorResponse, EmbeddingResponse]:
return self._create_error_response(
f"Inference engine {self.name} with version {self.version} "
"not support embeddings",
err_type="NotImplementedError",
status_code=HTTPStatus.NOT_IMPLEMENTED,
)


def get_inference_engine(engine: str, version: str, endpoint: str) -> InferenceEngine:
if engine.lower() == "vllm":
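The base class ships a NOT_IMPLEMENTED default, so engines opt in by overriding create_embeddings. A minimal sketch of such an override, assuming the aibrix package layout from this PR; the engine class itself is hypothetical, and a real backend would proxy to its own HTTP API instead of returning a stub:

```python
from typing import Union

from aibrix.openapi.engine.base import InferenceEngine
from aibrix.openapi.protocol import (
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
)


class MyEngine(InferenceEngine):  # hypothetical backend, not part of this PR
    async def create_embeddings(
        self, request: EmbeddingRequest
    ) -> Union[ErrorResponse, EmbeddingResponse]:
        # A real engine would call its own embeddings API here; this stub
        # returns a fixed vector so the sketch stays self-contained.
        return EmbeddingResponse(
            data=[{"embedding": [0.0, 0.0], "index": 0}],
            model=request.model,
            usage={"prompt_tokens": 0, "total_tokens": 0},
        )
```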
36 changes: 36 additions & 0 deletions python/aibrix/aibrix/openapi/engine/vllm.py
@@ -21,6 +21,8 @@
from aibrix.logger import init_logger
from aibrix.openapi.engine.base import InferenceEngine
from aibrix.openapi.protocol import (
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest,
@@ -139,3 +141,37 @@ async def list_models(self) -> Union[ErrorResponse, str]:
err_type="ServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)

async def create_embeddings(
self, request: EmbeddingRequest
) -> Union[ErrorResponse, EmbeddingResponse]:
embeddings_url = urljoin(self.endpoint, "/v1/embeddings")

try:
response = await self.client.post(
embeddings_url, json=request.model_dump(), headers=self.headers
)
except httpx.RequestError as e:
logger.error(f"Failed to create embeddings: {e}")
return self._create_error_response(
"Failed to create embeddings",
err_type="ServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)

if response.status_code != HTTPStatus.OK:
return self._create_error_response(
f"Failed to create embeddings: {response.text}",
err_type="ServerError",
status_code=HTTPStatus(value=response.status_code),
)

try:
return EmbeddingResponse(**response.json())
except Exception as e:
logger.error(f"Failed to parse embedding response: {e}")
return self._create_error_response(
"Invalid response from inference engine",
err_type="ServerError",
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
28 changes: 27 additions & 1 deletion python/aibrix/aibrix/openapi/protocol.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field

@@ -66,3 +66,29 @@ class ListModelRequest(NoExtraBaseModel):
class ListModelResponse(NoExtraBaseModel):
object: str = "list"
data: List[ModelStatusCard] = Field(default_factory=list)


class EmbeddingRequest(NoExtraBaseModel):
input: Union[str, List[str], List[int], List[List[int]]]
model: str
encoding_format: Optional[Literal["float", "base64"]] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None


class EmbeddingData(NoExtraBaseModel):
object: Literal["embedding"] = "embedding"
embedding: Union[List[float], str] # float array or base64 string
index: int


class EmbeddingUsage(NoExtraBaseModel):
prompt_tokens: int
total_tokens: int


class EmbeddingResponse(NoExtraBaseModel):
object: Literal["list"] = "list"
data: List[EmbeddingData]
model: str
usage: EmbeddingUsage
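A hedged sketch of round-tripping the new pydantic models, assuming only what this PR defines; the model name is a placeholder, and the input shapes mirror the gateway-side validator:

```python
from aibrix.openapi.protocol import EmbeddingRequest, EmbeddingResponse

req = EmbeddingRequest(model="my-embedding-model", input=["hello", "world"])
payload = req.model_dump()  # the body vllm.py posts to the engine

# EmbeddingData defaults object="embedding", so plain dicts coerce cleanly.
resp = EmbeddingResponse(
    data=[
        {"embedding": [0.1, 0.2], "index": 0},
        {"embedding": [0.3, 0.4], "index": 1},
    ],
    model="my-embedding-model",
    usage={"prompt_tokens": 4, "total_tokens": 4},
)
assert resp.object == "list" and resp.data[0].object == "embedding"
```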