diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index 1ef5f413973..4b94a1a9394 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -87,16 +87,16 @@ class MetadataServerConfig(): refresh_interval: float = 10.0 -def get_ctx_gen_server_urls( +def get_ctx_gen_server_addrs( server_configs: list[CtxGenServerConfig] ) -> tuple[list[str], list[str]]: ctx_server_urls = [] gen_server_urls = [] for cfg in server_configs: if cfg.type == "ctx": - ctx_server_urls.append(f"http://{cfg.hostname}:{cfg.port}") + ctx_server_urls.append(f"{cfg.hostname}:{cfg.port}") else: - gen_server_urls.append(f"http://{cfg.hostname}:{cfg.port}") + gen_server_urls.append(f"{cfg.hostname}:{cfg.port}") return ctx_server_urls, gen_server_urls diff --git a/tensorrt_llm/serve/disagg_auto_scaling.py b/tensorrt_llm/serve/disagg_auto_scaling.py index b30caed46b6..62a7b5bc400 100644 --- a/tensorrt_llm/serve/disagg_auto_scaling.py +++ b/tensorrt_llm/serve/disagg_auto_scaling.py @@ -4,7 +4,7 @@ import random import time from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Tuple +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from tensorrt_llm.llmapi.disagg_utils import DisaggClusterConfig, ServerRole from tensorrt_llm.logger import logger @@ -44,6 +44,7 @@ def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage): self._current_ctx_workers = {} # worker_id -> WorkerInfo self._current_gen_workers = {} # worker_id -> WorkerInfo self._watch_handle = None + self._watch_task = None def __del__(self): try: @@ -92,7 +93,14 @@ def current_gen_worker_num(self) -> int: def worker_key_prefix(self) -> str: return get_worker_key_prefix(self._config.cluster_name) - async def watch_workers(self, get_existing_first: bool = True): + async def watch_workers( + self, + get_existing_first: bool = True, + on_event: Optional[Callable[[WorkerInfo, WatchEventType], + Awaitable[Any]]] = None): + if self._watch_handle: + logger.error("Watch handle is already initialized") + return [] workers = [] self._watch_handle = await self._cluster_storage.watch( self.worker_key_prefix) @@ -109,12 +117,41 @@ async def watch_workers(self, get_existing_first: bool = True): workers.append(self._parse_worker_info(event)) events.append(event) await self._watch_handle.add_events(events) + + self._watch_handle = await self._cluster_storage.watch( + self.worker_key_prefix) + + async def worker_event_loop(): + logger.info( + f"Start watching worker events with {len(workers)} existing workers" + ) + for worker_info in workers: + await on_event(worker_info, WatchEventType.SET) + while True: + try: + worker_events = await self._watch_handle.drain() + for event in worker_events: + worker_info = self._parse_worker_info(event) + await on_event(worker_info, event.event_type) + except asyncio.CancelledError: + break + except Exception as e: + logger.error( + f"Error updating routers by worker events: {e}") + await asyncio.sleep(1) + logger.info("Stop watching worker events") + + if on_event: + self._watch_task = asyncio.create_task(worker_event_loop()) return workers async def unwatch_workers(self) -> None: if self._watch_handle: await self._cluster_storage.unwatch(self.worker_key_prefix) self._watch_handle = None + if self._watch_task: + self._watch_task.cancel() + self._watch_task = None async def get_worker_events( self) -> List[Tuple[WorkerInfo, WatchEventType]]: diff --git a/tensorrt_llm/serve/openai_client.py 
b/tensorrt_llm/serve/openai_client.py new file mode 100644 index 00000000000..e46a2326035 --- /dev/null +++ b/tensorrt_llm/serve/openai_client.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# yapf: disable +import asyncio +import traceback +from abc import ABC, abstractmethod +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Type + +import aiohttp + +from tensorrt_llm.llmapi.disagg_utils import ServerRole +from tensorrt_llm.logger import logger +from tensorrt_llm.serve.openai_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + UCompletionRequest, + UCompletionResponse, +) +from tensorrt_llm.serve.perf_metrics import ClientMetricsCollector +from tensorrt_llm.serve.responses_utils import ( + ResponseHooks, + UCompletionResponseOrGenerator, + get_steady_clock_now_in_seconds, +) +from tensorrt_llm.serve.router import Router + +# yapf: enable + + +class OpenAIClient(ABC): + async def send_request( + self, + request: UCompletionRequest, + server: Optional[str] = None, + hooks: Optional[ResponseHooks] = None, + ) -> UCompletionResponseOrGenerator: + if isinstance(request, CompletionRequest): + return await self._send_request( + "v1/completions", request, CompletionResponse, server, hooks + ) + elif isinstance(request, ChatCompletionRequest): + return await self._send_request( + "v1/chat/completions", request, ChatCompletionResponse, server, hooks + ) + else: + raise ValueError(f"Invalid request type: {type(request)}") + + @abstractmethod + async def _send_request( + self, + endpoint: str, + request: UCompletionRequest, + response_type: Type[UCompletionResponse], + server: Optional[str] = None, + hooks: Optional[ResponseHooks] = None, + ) -> UCompletionResponseOrGenerator: + """Send a request to the server and return the response and the body generator. + + The request is finished (in routers) when the generator is exhausted or there is an error. + """ + ... + + @abstractmethod + async def collect_metrics(self) -> Dict[str, Any]: ... + + @abstractmethod + async def check_ready(self) -> Tuple[List[str], List[str]]: + """Return the list of ready servers and the list of unready servers.""" + ... + + async def shutdown(self) -> None: ... + + @abstractmethod + async def _finish_request(self, request: UCompletionRequest) -> None: + """Finish the request in the router.""" + ... 
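The return type of `send_request` is intentionally dual-moded: a typed response object for non-streaming requests, and an async generator of raw SSE chunks for streaming ones, with the router only releasing the request once that generator is exhausted (or errors). A minimal, hedged consumption sketch — `client`, `request`, and the server address are placeholders for illustration, not part of this change:

```python
# Hedged sketch: consuming OpenAIClient.send_request in both modes.
# `client` is any concrete OpenAIClient and `request` an already-built
# CompletionRequest/ChatCompletionRequest; the address is illustrative.
async def consume(client: "OpenAIClient", request: "UCompletionRequest") -> None:
    result = await client.send_request(request, server="localhost:8001")
    if request.stream:
        # Streaming: an async generator of SSE byte chunks; the router only
        # counts the request as finished once the generator is drained.
        async for chunk in result:
            print(chunk.decode(errors="ignore"), end="")
    else:
        # Non-streaming: a fully parsed CompletionResponse / ChatCompletionResponse.
        print(result.model_dump())
```

Passing `server=None` lets the client ask its router for the next server, which is how the disaggregated service uses it below.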
+ + +class OpenAIHttpClient(OpenAIClient): + def __init__( + self, + router: Router, + role: ServerRole, + timeout_secs: int = 180, + max_retries: int = 1, + retry_interval_sec: int = 1, + session: Optional[aiohttp.ClientSession] = None, + ): + self._router = router + self._role = role + self._metrics_collector = ClientMetricsCollector(role) + self._session = session or aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=False), + timeout=aiohttp.ClientTimeout(total=timeout_secs), + ) + self._max_retries = max_retries + self._retry_interval_sec = retry_interval_sec + + async def _send_request( + self, + endpoint: str, + request: UCompletionRequest, + response_type: Type[UCompletionResponse], + server: Optional[str] = None, + hooks: Optional[ResponseHooks] = None, + ) -> UCompletionResponseOrGenerator: + if server is None: + server, _ = await self._router.get_next_server(request) + url = f"http://{server}/{endpoint}" + logger.debug( + f"Sending {self._role} request {request.disaggregated_params.ctx_request_id} to {url}" + ) + try: + self._metrics_collector.total_requests.inc() + resp_generator = self._post_with_retry(server, url, request, hooks) + if request.stream: + # return the response generator, the request is not done yet + return resp_generator + else: + # consume the generator to get the response and return it directly when it's not streaming + response = None + async for resp_json in resp_generator: + response = response_type(**resp_json) + if hooks: + if self._role == ServerRole.CONTEXT: + hooks.on_ctx_resp(server, response) + else: + hooks.on_first_token(server, request) + hooks.on_resp_done(server, request, response) + return response + except Exception: + self._metrics_collector.error_requests.inc() + # finish the request upon error + await self._finish_request(request) + raise + + async def _post_with_retry( + self, + server: str, + url: str, + request: UCompletionRequest, + hooks: Optional[ResponseHooks] = None, + ) -> AsyncGenerator[Any, None]: + json_data = request.model_dump(exclude_unset=True) + is_stream = request.stream + for attempt in range(self._max_retries + 1): + try: + start_time = get_steady_clock_now_in_seconds() + async with self._session.post(url, json=json_data) as http_response: + content_type = http_response.headers.get("Content-Type", "") + if not is_stream and "text/event-stream" in content_type: + raise ValueError( + "Received an event-stream although request stream was False" + ) + if is_stream: + # do NOT return generator directly here or the response will go + # out of scope and get destroyed + async for line in self._response_generator( + request, http_response, start_time, server, hooks + ): + yield line + # don't finish the request here since the response generator is not done yet + else: + http_response.raise_for_status() + response_dict = await http_response.json() + # yield here since python forbids return statements in async generators + yield response_dict + # finish the request after the successful response + await self._finish_request(request) + break # break and skip retries if the whole response is processed without exception + except (aiohttp.ClientError, OSError) as e: + if attempt == self._max_retries: + logger.error( + f"Client error to {url}: {e} - last retry {attempt} of {self._max_retries}" + "failed", + traceback.format_exc(), + ) + raise + logger.error( + f"{self._role} client error to {url}: {e} - retry {attempt} of {self._max_retries}", + traceback.format_exc(), + ) + await 
asyncio.sleep(self._retry_interval_sec) + self._metrics_collector.retry_requests.inc() + except Exception as e: + logger.error( + f"Unexpected error while processing {self._role} request to {url}: {e}" + ) + raise + + async def _response_generator( + self, + request: UCompletionRequest, + http_response: aiohttp.ClientResponse, + start_time: float, + server: str, + hooks: Optional[ResponseHooks] = None, + ) -> AsyncGenerator[Any, None]: + assert request.stream, "Request is not streaming" + assert "text/event-stream" in http_response.headers.get("Content-Type", ""), ( + "Response is not streaming" + ) + try: + last_token_time = start_time + i = 0 + async for line in http_response.content.iter_any(): + now_time = get_steady_clock_now_in_seconds() + if i == 0: + if hooks: + hooks.on_first_token(server, request) + self._metrics_collector.first_token_latency_seconds.observe( + now_time - last_token_time + ) + else: + self._metrics_collector.per_token_latency_seconds.observe( + now_time - last_token_time + ) + i += 1 + if line: + yield line + await asyncio.sleep(0) + last_token_time = now_time + + if hooks: + hooks.on_resp_done(server, request, None) + self._metrics_collector.completed_requests.inc() + self._metrics_collector.complete_latency_seconds.observe( + get_steady_clock_now_in_seconds() - start_time + ) + except aiohttp.ClientError as e: + # a client error is expected when the response stream is done if the connector has close=True + logger.error(f"{self._role} client {server} error: {e}") + self._metrics_collector.error_requests.inc() + raise + except Exception: + self._metrics_collector.error_requests.inc() + raise + finally: + # finish the request after streaming response is done or error is raised + await self._finish_request(request) + + async def _finish_request(self, request: UCompletionRequest) -> None: + await self._router.finish_request(request) + + async def collect_metrics(self) -> Dict[str, Any]: + metrics = {} + for server in self._router.servers: + try: + async with self._session.get(f"http://{server}/perf_metrics") as response: + metrics[server] = await response.json() + except Exception: + logger.error(f"Failed to collect metrics from {server}") + continue + return metrics + + async def shutdown(self) -> None: + await self._session.close() + + async def check_ready(self) -> Tuple[List[str], List[str]]: + return await OpenAIHttpClient.check_ready_for_servers(self._session, self._router.servers) + + @staticmethod + async def check_ready_for_servers( + session: aiohttp.ClientSession, servers: List[str] + ) -> Tuple[List[str], List[str]]: + async def check_server_ready(server: str) -> bool: + try: + url = ( + f"{server}/health" + if server.startswith("http://") + else f"http://{server}/health" + ) + async with session.get(url) as response: + return response.status == 200 + except Exception: + return False + + servers_ready = await asyncio.gather(*[check_server_ready(server) for server in servers]) + return [server for server, ready in zip(servers, servers_ready) if ready], [ + server for server, ready in zip(servers, servers_ready) if not ready + ] diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 1473a1cf29c..55c3e136e5a 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -1,49 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. #!/usr/bin/env python + +# yapf: disable import asyncio -import copy -import itertools -import os import signal import traceback -from collections import deque -from collections.abc import Mapping from contextlib import asynccontextmanager -from http import HTTPStatus -from typing import Callable, Optional, Type, Union +from typing import Callable, Optional import aiohttp import uvicorn from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR # yapf: disable from tensorrt_llm.executor import CppExecutorError +from tensorrt_llm.executor.executor import CppExecutorError from tensorrt_llm.llmapi import tracing from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, MetadataServerConfig, ServerRole, - get_ctx_gen_server_urls) + get_ctx_gen_server_addrs) from tensorrt_llm.logger import logger -from tensorrt_llm.serve.cluster_storage import (WatchEventType, +from tensorrt_llm.serve.cluster_storage import (HttpClusterStorageServer, create_cluster_storage) -from tensorrt_llm.serve.disagg_auto_scaling import (DisaggClusterManager, - WorkerInfo) from tensorrt_llm.serve.metadata_server import create_metadata_server -from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, - ChatCompletionResponse, - CompletionRequest, - CompletionResponse, - DisaggregatedParams, - ErrorResponse) +from tensorrt_llm.serve.openai_client import OpenAIClient, OpenAIHttpClient +from tensorrt_llm.serve.openai_disagg_service import ( + OpenAIDisaggregatedService, ResponseHooks) +from tensorrt_llm.serve.openai_protocol import (UCompletionRequest, + UCompletionResponse) +from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware, get_steady_clock_now_in_seconds) -from tensorrt_llm.serve.router import KvCacheAwareRouter, create_router +from tensorrt_llm.serve.router import Router, create_router from tensorrt_llm.version import __version__ as VERSION # yapf: enale TIMEOUT_KEEP_ALIVE = 10 # seconds. +class RawRequestResponseHooks(ResponseHooks): + def __init__(self, raw_req: Request, perf_metrics_collector: DisaggPerfMetricsCollector): + self.raw_req = raw_req + self.ctx_server = "" + self.gen_server = "" + self.server_first_token_time = 0 + self.perf_metrics_collector = perf_metrics_collector + + def on_req_begin(self, request: UCompletionRequest): + ... 
+ + def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse): + self.ctx_server = ctx_server + logger.debug(f"Received context response from {ctx_server} for request {response.choices[0].disaggregated_params.ctx_request_id}") + + def on_first_token(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None): + self.gen_server = gen_server + self.server_first_token_time = get_steady_clock_now_in_seconds() + logger.debug(f"Received first token from {gen_server} for request {request.disaggregated_params.ctx_request_id}") + + def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None): + if request.disaggregated_params: + ctx_req_id = request.disaggregated_params.ctx_request_id + asyncio.create_task(self.perf_metrics_collector.add_per_request_metrics(self.ctx_server, gen_server, ctx_req_id, self.raw_req.state.server_arrival_time, self.server_first_token_time)) + + class OpenAIDisaggServer: def __init__(self, @@ -52,122 +85,46 @@ def __init__(self, server_start_timeout_secs: int = 180, metadata_server_cfg: Optional[MetadataServerConfig] = None, metrics_interval_secs: int = 0): - self.ctx_servers, self.gen_servers = get_ctx_gen_server_urls(config.server_configs) - self.metadata_server = create_metadata_server(metadata_server_cfg) - self.ctx_router = create_router( - config.ctx_router_config, self.ctx_servers, metadata_server_cfg, self.metadata_server) - self.gen_router = create_router( - config.gen_router_config, self.gen_servers, metadata_server_cfg, self.metadata_server) - self.conditional_disagg_config = config.conditional_disagg_config - self.otlp_cfg = config.otlp_config + self._config = config + self._req_timeout_secs = req_timeout_secs + self._server_start_timeout_secs = server_start_timeout_secs + self._metadata_server_cfg = metadata_server_cfg + self._metrics_interval_secs = metrics_interval_secs + + self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs) + self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg)) + self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg)) + self._metadata_server = create_metadata_server(metadata_server_cfg) + self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests) + + self._disagg_cluster_storage = create_cluster_storage(config.disagg_cluster_config.cluster_uri, config.disagg_cluster_config.cluster_name) if config.disagg_cluster_config else None + + self._service = OpenAIDisaggregatedService( + self._config, self._ctx_router, self._gen_router, self._create_client, + metadata_server=self._metadata_server, + metadata_config=self._metadata_server_cfg, + req_timeout_secs=self._req_timeout_secs, + server_start_timeout_secs=self._server_start_timeout_secs, + perf_metrics_collector=self._perf_metrics_collector, + disagg_cluster_storage=self._disagg_cluster_storage) try: - if self.otlp_cfg and self.otlp_cfg.otlp_traces_endpoint: - tracing.init_tracer("trt.llm", self.otlp_cfg.otlp_traces_endpoint) + otlp_cfg = config.otlp_config + if otlp_cfg and otlp_cfg.otlp_traces_endpoint: + tracing.init_tracer("trt.llm", otlp_cfg.otlp_traces_endpoint) logger.info( - f"Initialized OTLP tracer successfully, endpoint: {self.otlp_cfg.otlp_traces_endpoint}" + f"Initialized OTLP tracer successfully, endpoint: {otlp_cfg.otlp_traces_endpoint}" ) except 
Exception as e: logger.error(f"Failed to initialize OTLP tracer: {e}") - self.perf_metrics_max_requests = config.perf_metrics_max_requests - if self.perf_metrics_max_requests > 0: - # record corresponding keys of context and generation servers for perf metrics - # (ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time) - self.perf_metrics_keys = deque(maxlen=self.perf_metrics_max_requests) - self.perf_metrics_keys_lock = asyncio.Lock() - # server_url -> {ctx_request_id: perf_metrics} - self.server_perf_metrics: dict[str, dict[int, dict]] = {} - - else: - self.perf_metrics_keys = None - self.perf_metrics_keys_lock = None - self.server_perf_metrics = None - - if config.max_retries < 0: - raise ValueError(f"Max retries {config.max_retries} must be greater than or equal to 0") - self.max_retries = config.max_retries - # Metrics counters and synchronization - self._metrics = { - "ctx_total_requests": 0, - "ctx_completed_requests": 0, - "gen_total_requests": 0, - "gen_completed_requests": 0, - } - self._metrics_lock = asyncio.Lock() - self._metrics_task = None - self.metrics_interval_secs = metrics_interval_secs - - self.disagg_cluster_config = config.disagg_cluster_config - self.disagg_cluster_storage = None - self.disagg_cluster_manager = None - self._update_worker_task = None - - logger.info(f"Server max retries: {self.max_retries}") - - if self.disagg_cluster_config is None: - if (len(self.gen_servers) == 0): - raise ValueError("At least one generation server must be provided") - - if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") != "1" and len(self.ctx_servers) == 0: - raise ValueError("At least one context server must be provided") - - if self.conditional_disagg_config is not None and \ - not isinstance(self.gen_router, KvCacheAwareRouter): - raise ValueError("Generation router must be a KvCacheAwareRouter to enable conditional disaggregation") - - if self.disagg_cluster_config and self.metadata_server: - raise ValueError("Cluster manager and metadata server cannot be used together") - - # Session will be initialized in lifespan - self.session: Optional[aiohttp.ClientSession] = None @asynccontextmanager - async def lifespan(app: FastAPI): - # Create a persistent aiohttp ClientSession - self.session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True), - timeout=aiohttp.ClientTimeout(total=req_timeout_secs)) - - if self.disagg_cluster_manager: - await self.disagg_cluster_manager.start() - await self.disagg_cluster_manager.watch_workers() - self._update_worker_task = asyncio.create_task(self._update_router_by_watch_events()) - - logger.info("Waiting for context and generation servers to be ready") - await self.wait_for_servers_ready(server_start_timeout_secs) - - if self.perf_metrics_max_requests > 0: - await self.set_steady_clock_offsets(self.session) - - if self.metadata_server: - logger.info("Starting server monitoring via metadata service") - await self.ctx_router.start_server_monitoring(metadata_server_cfg.refresh_interval) - await self.gen_router.start_server_monitoring(metadata_server_cfg.refresh_interval) - - # Start periodic metrics logging - if self.metrics_interval_secs > 0: - self._metrics_task = asyncio.create_task(self._log_metrics_periodically(self.metrics_interval_secs)) - + async def lifespan(app) -> None: + await self._service.setup() + await self._set_steady_clock_offsets() yield - - if self.metadata_server: - logger.info("Stopping server monitoring via metadata service") - await 
self.ctx_router.stop_server_monitoring() - await self.gen_router.stop_server_monitoring() - - # Stop periodic metrics logging - if self._metrics_task is not None: - self._metrics_task.cancel() - try: - await self._metrics_task - except asyncio.CancelledError: - pass - - await self.session.close() # Ensure session cleanup - if self.disagg_cluster_manager: - self._update_worker_task.cancel() - await self.disagg_cluster_manager.stop() + await self._service.teardown() self.app = FastAPI(lifespan=lifespan) @@ -178,411 +135,73 @@ async def validation_exception_handler(_, exc): return JSONResponse(status_code=400, content={"error": str(exc)}) self.register_routes() - if self.disagg_cluster_config: - self.disagg_cluster_storage = create_cluster_storage(self.disagg_cluster_config.cluster_uri, self.disagg_cluster_config.cluster_name, server=self.app) - self.disagg_cluster_manager = DisaggClusterManager(self.disagg_cluster_config, self.disagg_cluster_storage) - - async def _increment_metric(self, key: str, amount: int = 1): - if self.metrics_interval_secs > 0: - async with self._metrics_lock: - self._metrics[key] += amount - - async def _get_metrics_snapshot(self): - async with self._metrics_lock: - return dict(self._metrics) - - async def _log_metrics_periodically(self, interval_seconds: int): - try: - while True: - await asyncio.sleep(interval_seconds) - snapshot = await self._get_metrics_snapshot() - logger.info( - ( - f"[Statistics] total_context_requests={snapshot['ctx_total_requests']}, completed_context_requests={snapshot['ctx_completed_requests']}, " - f"total_generation_requests={snapshot['gen_total_requests']}, completed_generation_requests={snapshot['gen_completed_requests']}" - ) - ) - except asyncio.CancelledError: - pass - - @staticmethod - def create_error_response( - message: str, - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: - raise HTTPException(status_code=500, detail=f"Internal server error {message}") + def _create_client(self, router: Router, role: ServerRole, max_retries: int = 1) -> OpenAIClient: + client = OpenAIHttpClient(router, role, self._req_timeout_secs, max_retries) + self._perf_metrics_collector.add_client(client) + return client def register_routes(self): + self.app.add_api_route("/v1/completions", self._wrap_entry_point(self._service.openai_completion), methods=["POST"]) + self.app.add_api_route("/v1/chat/completions", self._wrap_entry_point(self._service.openai_chat_completion), methods=["POST"]) self.app.add_api_route("/health", self.health, methods=["GET"]) - self.app.add_api_route("/version", self.version, methods=["GET"]) - self.app.add_api_route("/perf_metrics", self.perf_metrics, methods=["GET"]) - self.app.add_api_route("/v1/completions", - self.openai_completion, - methods=["POST"]) - self.app.add_api_route("/v1/chat/completions", - self.openai_chat_completion, - methods=["POST"]) self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"]) + self.app.add_api_route("/version", self.version, methods=["GET"]) + self.app.add_api_route("/perf_metrics", self._perf_metrics_collector.get_perf_metrics, methods=["GET"]) + # import prometheus_client lazily to break the `set_prometheus_multiproc_dir` + from prometheus_client import make_asgi_app + self.app.mount("/prometheus/metrics", make_asgi_app()) + if self._disagg_cluster_storage and isinstance(self._disagg_cluster_storage, HttpClusterStorageServer): + self._disagg_cluster_storage.add_routes(self.app) + + def _wrap_entry_point(self, entry_point: Callable) -> Callable: + 
async def wrapper(req: UCompletionRequest, raw_req: Request) -> Response: + try: + hooks = RawRequestResponseHooks(raw_req, self._perf_metrics_collector) + response_or_generator = await entry_point(req, hooks) + if req.stream: + return StreamingResponse(content=response_or_generator, media_type="text/event-stream") + else: + return JSONResponse(content=response_or_generator.model_dump()) + except Exception as e: + self._handle_exception(e) + return wrapper - async def health(self) -> Response: - if not await self.is_ready(): - return Response(status_code=500) - return Response(status_code=200) - - async def version(self) -> JSONResponse: - ver = {"version": VERSION} - return JSONResponse(content=ver) - - async def cluster_info(self) -> JSONResponse: - if self.disagg_cluster_manager: - cluster_info = await self.disagg_cluster_manager.cluster_info() - cluster_info["is_ready"] = await self.is_ready() - return JSONResponse(content=cluster_info) - return JSONResponse(content={}) - - async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request): - async with self.perf_metrics_keys_lock: - self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_time)) - - async def perf_metrics(self) -> JSONResponse: - if self.perf_metrics_keys is None: - return JSONResponse(content=[]) - - perf_metrics = {} - exc = None - try: - for server in self.ctx_servers + self.gen_servers: - async with self.session.get(f"{server}/perf_metrics") as response: - server_perf_metrics = await response.json() - perf_metrics[server] = server_perf_metrics - except Exception as e: - # Keep the exception to raise it after saving perf metrics - exc = e - - return_metrics = [] - async with self.perf_metrics_keys_lock: - for server in perf_metrics: - server_metrics = self.server_perf_metrics.setdefault(server, {}) - for request_perf_metrics in perf_metrics[server]: - ctx_request_id = request_perf_metrics.get("ctx_request_id", None) - if ctx_request_id is None: - continue - server_metrics[ctx_request_id] = request_perf_metrics - - if len(server_metrics) > self.perf_metrics_max_requests: - # Remove oldest requests and keep at most perf_metrics_max_requests - num_remove = len(server_metrics) - self.perf_metrics_max_requests - removed_keys = list(itertools.islice(server_metrics.keys(), num_remove)) - for ctx_request_id in removed_keys: - server_metrics.pop(ctx_request_id) - if exc is not None: - raise exc - - remain_keys = [] - for ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time in self.perf_metrics_keys: - gen_perf_metrics = self.server_perf_metrics[gen_server].pop(ctx_request_id, None) - if gen_perf_metrics is None: - # generation not finished - remain_keys.append((ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time)) - continue - ctx_perf_metrics = self.server_perf_metrics[ctx_server].pop(ctx_request_id, None) - return_metrics.append({ - "ctx_server": ctx_server, - "gen_server": gen_server, - "disagg_server_arrival_time": server_arrival_time, - "disagg_server_first_token_time": server_first_token_time, - "ctx_perf_metrics": ctx_perf_metrics, - "gen_perf_metrics": gen_perf_metrics}) - self.perf_metrics_keys = deque(remain_keys, maxlen=self.perf_metrics_max_requests) - - return JSONResponse(content=return_metrics) - - @tracing.trace_span("llm_request") - async def openai_completion(self, req: CompletionRequest, raw_request: 
Request) -> Response: - if not await self.is_ready(): - raise HTTPException(status_code=400, detail="Cluster is not ready") - try: - if not isinstance(req.prompt, str): - # Check if it's a list and contains integers - if type(req.prompt) is list and len(req.prompt) == 1: - req.prompt = req.prompt[0] - elif not isinstance(req.prompt, list) or not all(isinstance(x, int) for x in req.prompt): - raise ValueError("Disaggregated server currently only supports single string prompt or list of integers in request") - - return await self._send_disagg_request(req, raw_request) - - except Exception as e: - await self._handle_exception(e) - - @tracing.trace_span("llm_request") - async def openai_chat_completion(self, req: ChatCompletionRequest, raw_request: Request) -> Response: - if not await self.is_ready(): - raise HTTPException(status_code=400, detail="Cluster is not ready") - try: - return await self._send_disagg_request(req, raw_request) - except Exception as e: - await self._handle_exception(e) - - async def _handle_exception(self, exception): + def _handle_exception(self, exception): if isinstance(exception, CppExecutorError): - logger.error(traceback.format_exc()) + logger.error("CppExecutorError: ", traceback.format_exc()) signal.raise_signal(signal.SIGINT) elif isinstance(exception, HTTPException): - raise exception # Re-raise HTTP exceptions properly + logger.error(f"HTTPException {exception.status_code} {exception.detail}: ", traceback.format_exc()) + raise exception else: - logger.error(traceback.format_exc()) + logger.error("Internal server error: ", traceback.format_exc()) raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}") - async def _send_context_request(self, ctx_server: str, ctx_req: Union[CompletionRequest, ChatCompletionRequest], - trace_headers: Optional[Mapping[str, str]] = None): - - ctx_req.disaggregated_params = DisaggregatedParams(request_type="context_only") - ctx_req.stream = False - ctx_req.stream_options = None - - logger.debug("Sending request to ctx server: %s", ctx_server) - await self._increment_metric("ctx_total_requests") - try: - if isinstance(ctx_req, ChatCompletionRequest): - ctx_response = await self.send_chat_request(ctx_server, ctx_req, trace_headers) - else: - assert isinstance(ctx_req, CompletionRequest) - ctx_response = await self.send_completion_request(ctx_server, ctx_req, trace_headers) - finally: - await self.ctx_router.finish_request(ctx_req) - await self._increment_metric("ctx_completed_requests") - - choices = ctx_response.choices - if len(choices) > 1: - raise ValueError("Disagg server returned more than one choice. 
This is currently not supported in disaggregated server.") - if choices[0].disaggregated_params is None: - raise ValueError("Context server did not return disaggregated params") - if choices[0].disaggregated_params.ctx_request_id is None: - raise ValueError("Invalid disaggregated params in context phase response.") - - return ctx_response - - async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletionRequest], raw_request: Request): - ctx_server = None - gen_server = None - ctx_request_id = None - need_ctx = False - trace_headers = tracing.inject_trace_headers(raw_request.headers) - - async def _merge_streaming_responses(ctx_response, - gen_req: Union[CompletionRequest, ChatCompletionRequest], - trace_headers: Optional[Mapping[str, str]] = None): - try: - if ctx_response is not None and len(ctx_response.choices) != 1: - raise ValueError("Context server did not return a single choice. This is not expected") - - #If request finished after first token not due to length, return right away and skip gen - if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]: - yield "data: [DONE]\n\n".encode('utf-8') - else: - # Then yield the generation responses - await self._increment_metric("gen_total_requests") - if isinstance(gen_req, CompletionRequest): - gen_response = await self.send_completion_request(gen_server, gen_req, trace_headers) - elif isinstance(gen_req, ChatCompletionRequest): - gen_response = await self.send_chat_request(gen_server, gen_req, trace_headers) - else: - raise TypeError("Invalid request type: {type(gen_req).__name__}") - - first_response = await anext(gen_response.body_iterator) - raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() - yield first_response - async for chunk in gen_response.body_iterator: - yield chunk - await self._increment_metric("gen_completed_requests") - if need_ctx and self.perf_metrics_keys is not None: - asyncio.create_task(self._add_perf_metrics_keys( - ctx_server, gen_server, ctx_request_id, raw_request)) + async def health(self) -> Response: + if not await self._service.is_ready(): + return Response(status_code=500) + return Response(status_code=200) - finally: - await self.gen_router.finish_request(gen_req) - try: - # Determine if need context server - condition = self.conditional_disagg_config - if condition is not None: - assert isinstance(self.gen_router, KvCacheAwareRouter) - # Query kv cache status and select a best gen_server. 
- # The server is reserved for generation request - gen_server, info = await self.gen_router.get_next_server(req) - match_length = sum(info["matches"]) - total_length = sum(len(token_list) for token_list in info["token_lists"]) - if match_length == 0 or total_length - match_length > condition.max_local_prefill_length: - need_ctx = True - elif os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1": - # Hard-code first token, ctx_request_id for testing - req.disaggregated_params = DisaggregatedParams( - request_type="generation_only", - first_gen_tokens=[7], - ctx_request_id=1, - encoded_opaque_state=None, - draft_tokens=None) - # Since KV cache for prompt tokens will be uninitialized, need to ignore eos - req.ignore_eos = True - else: - need_ctx = True - - if need_ctx: - ctx_req = copy.deepcopy(req) - ctx_server, _ = await self.ctx_router.get_next_server(ctx_req) - - tracing.add_event(tracing.SpanEvents.CTX_SERVER_SELECTED, attributes={"server": str(ctx_server),}) - - # TODO: add ctx_server info into generation request for pre-registration - ctx_response = await self._send_context_request(ctx_server, ctx_req, trace_headers) - - if ctx_response is not None and len(ctx_response.choices) != 1: - raise ValueError("Context server did not return a single choice. This is not expected") - - # Append disaggregates parameters to generation request - req.disaggregated_params = ctx_response.choices[0].disaggregated_params - req.disaggregated_params.request_type = "generation_only" - ctx_request_id = req.disaggregated_params.ctx_request_id - - # Replace the string prompt with prompt_tokens_ids - if isinstance(req, CompletionRequest): - req.prompt = ctx_response.prompt_token_ids - elif isinstance(req, ChatCompletionRequest): - req.prompt_token_ids = ctx_response.prompt_token_ids - else: - raise ValueError("Invalid request type: {type(req).__name__}") - else: - ctx_response = None - - # Pick a generation server if haven't reserved one, and send request - if gen_server is None: - gen_server, _ = await self.gen_router.get_next_server(req) - logger.debug("Sending request to gen server: %s", gen_server) - tracing.add_event(tracing.SpanEvents.GEN_SERVER_SELECTED,attributes={"server": str(gen_server),}) - - if not req.stream: - try: - #If request finished after first token for reason other than length, return right away and skip gen - if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length","not_finished"]: - del ctx_response.choices[0].disaggregated_params - return ctx_response - else: - await self._increment_metric("gen_total_requests") - if isinstance(req, CompletionRequest): - gen_response = await self.send_completion_request(gen_server, req, trace_headers) - else: - assert isinstance(req, ChatCompletionRequest) - gen_response = await self.send_chat_request(gen_server, req, trace_headers) - await self._increment_metric("gen_completed_requests") - if need_ctx and self.perf_metrics_keys is not None: - raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() - asyncio.create_task(self._add_perf_metrics_keys( - ctx_server, gen_server, ctx_request_id, raw_request)) - return gen_response - finally: - if gen_server is not None: - await self.gen_router.finish_request(req) - - else: - # Return a streaming response that combines both context and generation responses - return StreamingResponse( - _merge_streaming_responses(ctx_response, req, trace_headers), - media_type="text/event-stream" - ) - except: - if gen_server is not None: - await 
self.gen_router.finish_request(req) - raise + async def cluster_info(self) -> JSONResponse: + return JSONResponse(content=await self._service.cluster_info()) + async def version(self) -> JSONResponse: + return JSONResponse(content={"version": VERSION}) - async def __call__(self, host, port): + async def __call__(self, host: str, port: int): config = uvicorn.Config(self.app, host=host, port=port, - log_level="info", + log_level=logger.level, timeout_keep_alive=TIMEOUT_KEEP_ALIVE) await uvicorn.Server(config).serve() - async def create_generator(self, url: str, request: Union[CompletionRequest, ChatCompletionRequest], - end_point: str, trace_headers: Optional[Mapping[str, str]] = None): - # Prepare headers - headers = {"Content-Type": "application/json"} - if trace_headers: - headers.update(trace_headers) - - async with self.session.post(url + end_point, json=request.model_dump(exclude_unset=True), - headers=headers) as response: - content_type = response.headers.get("Content-Type", "") - if "text/event-stream" in content_type: - if not request.stream: - raise ValueError("Received an event-stream although request stream was False") - - try: - async for line in response.content.iter_any(): - if line: - yield line - await asyncio.sleep(0) - except Exception as e: - logger.error(f"Unexpected error in stream: {e}") - raise - - async def create_completion_generator(self, url: str, request: CompletionRequest, - trace_headers: Optional[Mapping[str, str]] = None): - async for chunk in self.create_generator(url, request, "/v1/completions", trace_headers): - yield chunk - - async def create_chat_generator(self, url: str, request: ChatCompletionRequest, - trace_headers: Optional[Mapping[str, str]] = None): - async for chunk in self.create_generator(url, request, "/v1/chat/completions", trace_headers): - yield chunk - - async def send_request(self, url: str, - request: Union[CompletionRequest, ChatCompletionRequest], - endpoint: str, - response_type: Type[Union[CompletionResponse, ChatCompletionResponse]], - create_generator: Callable, - trace_headers: Optional[Mapping[str, str]] = None) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]: - for attempt in range(self.max_retries + 1): - try: - headers = {"Content-Type": "application/json"} - if trace_headers: - headers.update(trace_headers) - if request.stream: - response_generator = create_generator(url, request, headers) - return StreamingResponse(content=response_generator, media_type="text/event-stream") - else: - async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True), - headers=headers) as response: - content_type = response.headers.get("Content-Type", "") - if "text/event-stream" in content_type: - raise ValueError("Received an event-stream although request stream was False") - - response_dict = await response.json() - if not response.ok: - logger.error(f"Received failed response {response_dict}") - response.raise_for_status() - return response_type(**response_dict) - except (aiohttp.ClientError, OSError) as e: - if attempt == self.max_retries: - raise HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Internal server error") from e - logger.error(f"Client error: {e} - retry {attempt} of {self.max_retries}") - # TODO : add a configurable retry interval - await asyncio.sleep(1) - except Exception as e: - logger.error(f"Error encountered while processing request to {url+endpoint}: {e}") - raise - - async def send_completion_request(self, url: str, request: CompletionRequest, - 
trace_headers: Optional[Mapping[str, str]] = None) -> Union[CompletionResponse, StreamingResponse]: - return await self.send_request(url, request, "/v1/completions", CompletionResponse, self.create_completion_generator, trace_headers) - - async def send_chat_request(self, url: str, request: ChatCompletionRequest, - trace_headers: Optional[Mapping[str, str]] = None) -> ChatCompletionResponse: - return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator, trace_headers) - - async def set_steady_clock_offsets(self, session: aiohttp.ClientSession): + # TODO: rework this for service discovery, now it's only for static server list + async def _set_steady_clock_offsets(self): STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset" - async def query_steady_clock_offset(server_url: str) -> tuple[Optional[float], Optional[float]]: + async def query_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> tuple[Optional[float], Optional[float]]: try: originate_ts = get_steady_clock_now_in_seconds() async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response: @@ -599,79 +218,24 @@ async def query_steady_clock_offset(server_url: str) -> tuple[Optional[float], O return None, None except Exception: return None, None - async def set_steady_clock_offset(server_url: str, offset: float) -> None: + + async def set_steady_clock_offset(session: aiohttp.ClientSession, server_url: str, offset: float) -> None: payload = {"offset": offset} async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response: if response.status != 200: logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned") - for server_url in self.ctx_servers + self.gen_servers: - delay, offset = await query_steady_clock_offset(server_url) + + async def align_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> None: + server_url = f"http://{server_url}" if not server_url.startswith("http://") else server_url + delay, offset = await query_steady_clock_offset(session, server_url) if delay is None or offset is None: logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment") - continue + return logger.info(f'Server: {server_url}, delay: {delay} second, offset: {offset} second') # Negate the offset so that worker servers can adjust their steady clock by adding the new offset - await set_steady_clock_offset(server_url, -offset) + await set_steady_clock_offset(session, server_url, -offset) - @classmethod - async def check_server_ready(cls, session: aiohttp.ClientSession, server_url: str) -> bool: - try: - async with session.get(server_url+"/health") as response: - return response.status == 200 - except Exception: - return False - - @classmethod - async def wait_for_all_servers_ready(cls, session: aiohttp.ClientSession, - ctx_servers: list[str], - gen_servers: list[str], - server_start_timeout_secs: int = 180): - async def get_unready_servers(servers: list[str]) -> list[str]: - servers_ready = await asyncio.gather(*[cls.check_server_ready(session, server) for server in servers]) - return [server for server, ready in zip(servers, servers_ready) if not ready] - - async def check_all_servers_ready(): - iter = 0 - unready_servers = await get_unready_servers(ctx_servers + gen_servers) - while len(unready_servers) > 0: - wait_time = 3 - logger.info( - f"[{iter}] Servers are not ready. 
Waiting for {unready_servers}..." - ) - await asyncio.sleep(wait_time) - iter += 1 - unready_servers = await get_unready_servers(unready_servers) - try: - await asyncio.wait_for(check_all_servers_ready(), timeout=server_start_timeout_secs) - except asyncio.CancelledError: - raise TimeoutError("Timeout waiting for context and generation servers to be ready") - logger.info("Context and generation servers are ready") - - async def is_ready(self) -> bool: - if self.disagg_cluster_manager: - return await self.disagg_cluster_manager.is_ready_with_router(len(self.ctx_router.servers), len(self.gen_router.servers)) - return True - - async def wait_for_servers_ready(self, server_start_timeout_secs: int = 180): - await self.wait_for_all_servers_ready(self.session, self.ctx_servers, self.gen_servers, server_start_timeout_secs) - - async def _update_router_by_watch_events(self): - def worker_repr(worker_info: WorkerInfo): - return f"http://{worker_info.host}:{worker_info.port}" - router_map = { - ServerRole.CONTEXT: self.ctx_router, - ServerRole.GENERATION: self.gen_router - } - logger.info("Start updating routers by worker events") - while True: - try: - worker_events = await self.disagg_cluster_manager.get_worker_events() - for worker_info, event_type in worker_events: - if event_type == WatchEventType.SET: - await router_map[worker_info.role].add_server(worker_repr(worker_info)) - elif event_type == WatchEventType.DELETE: - await router_map[worker_info.role].remove_server(worker_repr(worker_info)) - logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}") - except Exception as e: - logger.error(f"Error updating routers by worker events: {e}") - await asyncio.sleep(1) + async with aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True), + timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session: + await asyncio.gather(*[align_steady_clock_offset(session, server_url) for server_url in self._ctx_servers + self._gen_servers]) diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py new file mode 100644 index 00000000000..d1f8d8dad7f --- /dev/null +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -0,0 +1,298 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
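The `/steady_clock_offset` exchange above is a two-way timestamp probe: the disagg server records when the probe was sent and when the reply arrived, the worker reports when it received and answered it, and from those four timestamps a round-trip delay and a clock offset fall out; the disagg server then pushes the negated offset back so the worker's steady clock lines up with its own. Only the `(delay, offset)` pair and the negated POST are visible in this hunk; the worker-side timestamp names below are assumptions for illustration:

```python
# Hedged sketch of the NTP-style arithmetic behind query_steady_clock_offset,
# assuming the worker reports when it received and answered the probe.
def estimate_clock_offset(originate_ts: float, receive_ts: float,
                          transmit_ts: float, destination_ts: float) -> tuple[float, float]:
    """Return (round_trip_delay, worker_clock_offset) in seconds."""
    delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts)
    offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2
    return delay, offset

# The disagg server negates the estimate before posting it, so the worker
# adds -offset to its steady clock readings:
#     await set_steady_clock_offset(session, server_url, -offset)
```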
+ +import asyncio +import copy +import os +from typing import Any, Callable, Dict, Optional + +from tensorrt_llm.llmapi.disagg_utils import ( + ConditionalDisaggConfig, + DisaggClusterConfig, + DisaggServerConfig, + MetadataServerConfig, + ServerRole, +) +from tensorrt_llm.logger import logger +from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType +from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterManager, WorkerInfo +from tensorrt_llm.serve.metadata_server import JsonDictionary +from tensorrt_llm.serve.openai_client import OpenAIClient +from tensorrt_llm.serve.openai_protocol import ( + ChatCompletionRequest, + CompletionRequest, + DisaggregatedParams, + UCompletionRequest, + UCompletionResponse, +) +from tensorrt_llm.serve.openai_service import OpenAIService +from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector +from tensorrt_llm.serve.responses_utils import ( + ResponseHooks, + UCompletionResponseOrGenerator, + done_generator, +) +from tensorrt_llm.serve.router import KvCacheAwareRouter, Router + + +class OpenAIDisaggregatedService(OpenAIService): + def __init__( + self, + config: DisaggServerConfig, + ctx_router: Router, + gen_router: Router, + client_factory: Callable[[Router, ServerRole], OpenAIClient], + metadata_server: Optional[JsonDictionary] = None, + metadata_config: Optional[MetadataServerConfig] = None, + req_timeout_secs: int = 180, + server_start_timeout_secs: int = 180, + perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None, + disagg_cluster_storage: Optional[ClusterStorage] = None, + health_check_interval_secs: int = 3, + ): + self._config = config + self._ctx_router = ctx_router + self._gen_router = gen_router + self._client_factory = client_factory + self._metadata_server = metadata_server + self._metadata_config = metadata_config + self._req_timeout_secs = req_timeout_secs + self._server_start_timeout_secs = server_start_timeout_secs + self._perf_metrics_collector = perf_metrics_collector + self._cluster_storage = disagg_cluster_storage + self._health_check_interval_secs = health_check_interval_secs + + self._ctx_client = None + self._gen_client = None + self._disagg_cluster_manager = None + + async def openai_completion( + self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None + ) -> UCompletionResponseOrGenerator: + if not await self.is_ready(): + raise RuntimeError("Cluster is not ready") + if not isinstance(request.prompt, str): + # Check if it's a list and contains integers + if type(request.prompt) is list and len(request.prompt) == 1: + request.prompt = request.prompt[0] + elif not isinstance(request.prompt, list) or not all( + isinstance(x, int) for x in request.prompt + ): + raise ValueError( + "Disaggregated server currently only supports single string prompt or list of integers in request" + ) + + return await self._send_disagg_request(request, hooks) + + async def openai_chat_completion( + self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None + ) -> UCompletionResponseOrGenerator: + if not await self.is_ready(): + raise RuntimeError("Cluster is not ready") + return await self._send_disagg_request(request, hooks) + + async def _send_disagg_request( + self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None + ) -> UCompletionResponseOrGenerator: + if hooks: + hooks.on_req_begin(request) + # empty server means client decides which server to use + reserved_gen_server = None + reserved_ctx_server = None + # reserve a gen_server if 
conditional disagg is needed + reserved_gen_server, need_ctx = await self._check_conditional_disagg(request) + need_ctx = need_ctx and not await self._check_gen_only_disagg(request) + ctx_response = None + gen_req = request + if need_ctx: + ctx_req = self._get_ctx_request(request) + # ctx generator is empty + ctx_response = await self._ctx_client.send_request( + ctx_req, server=reserved_ctx_server, hooks=hooks + ) + await self._verify_ctx_response(ctx_response) + gen_req = self._get_gen_request(request, ctx_response) + if ctx_response is None or self._need_gen(ctx_response): + return await self._gen_client.send_request( + gen_req, server=reserved_gen_server, hooks=hooks + ) + else: + if request.stream: + # ctx client will never return a generator when streaming is requested + # make up for this by returning a done generator + return done_generator() + return ctx_response + + def _need_gen(self, response: UCompletionResponse) -> bool: + if response and response.choices[0].finish_reason not in ["length", "not_finished"]: + del response.choices[0].disaggregated_params + return False + return True + + def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest: + ctx_request = copy.deepcopy(request) + ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only") + ctx_request.stream = False + ctx_request.stream_options = None + return ctx_request + + def _get_gen_request( + self, + request: UCompletionRequest, + ctx_response: UCompletionResponse, + ) -> UCompletionRequest: + request.disaggregated_params = ctx_response.choices[0].disaggregated_params + request.disaggregated_params.request_type = "generation_only" + # Replace the string prompt with prompt_tokens_ids + if isinstance(request, CompletionRequest): + request.prompt = ctx_response.prompt_token_ids + elif isinstance(request, ChatCompletionRequest): + request.prompt_token_ids = ctx_response.prompt_token_ids + return request + + async def _check_conditional_disagg(self, request: UCompletionRequest) -> tuple[Optional[str], bool]: + if self.conditional_disagg_config: + assert isinstance(self._gen_router, KvCacheAwareRouter) + # Query kv cache status and select a best gen_server.
+ # The server is reserved for generation request + gen_server, info = await self._gen_router.get_next_server(request) + match_length = sum(info["matches"]) + total_length = sum(len(token_list) for token_list in info["token_lists"]) + if ( + match_length == 0 + or total_length - match_length + > self.conditional_disagg_config.max_local_prefill_length + ): + return gen_server, True + return gen_server, False + return None, True + + async def _check_gen_only_disagg(self, request: UCompletionRequest) -> bool: + if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1": + # Hard-code first token, ctx_request_id for testing + request.disaggregated_params = DisaggregatedParams( + request_type="generation_only", + first_gen_tokens=[7], + ctx_request_id=1, + encoded_opaque_state=None, + draft_tokens=None, + ) + request.ignore_eos = True + return True + return False + + async def cluster_info(self) -> Dict[str, Any]: + cluster_info = {"is_ready": await self.is_ready()} + if self._disagg_cluster_manager: + cluster_info.update(await self._disagg_cluster_manager.cluster_info()) + return cluster_info + + async def is_ready(self) -> bool: + if self._disagg_cluster_manager: + return await self._disagg_cluster_manager.is_ready() + return True + + @property + def disagg_cluster_config(self) -> Optional[DisaggClusterConfig]: + return self._config.disagg_cluster_config + + @property + def conditional_disagg_config(self) -> Optional[ConditionalDisaggConfig]: + return self._config.conditional_disagg_config + + async def setup(self) -> None: + self._ctx_client = self._client_factory( + self._ctx_router, ServerRole.CONTEXT, self._config.max_retries + ) + self._gen_client = self._client_factory( + self._gen_router, ServerRole.GENERATION, self._config.max_retries + ) + + if self.disagg_cluster_config and self._cluster_storage: + logger.info("Starting disagg cluster manager") + self._disagg_cluster_manager = DisaggClusterManager( + self.disagg_cluster_config, self._cluster_storage + ) + await self._disagg_cluster_manager.start() + await self._disagg_cluster_manager.watch_workers(on_event=self._on_worker_event) + logger.info("Disagg cluster manager started") + else: + if self._metadata_server and self._metadata_config: + logger.info("Starting server monitoring via metadata service") + await self._ctx_router.start_server_monitoring( + self._metadata_config.refresh_interval + ) + await self._gen_router.start_server_monitoring( + self._metadata_config.refresh_interval + ) + await self._wait_for_all_servers_ready() + + async def teardown(self) -> None: + await self._ctx_client.shutdown() + await self._gen_client.shutdown() + + if self._disagg_cluster_manager: + await self._disagg_cluster_manager.stop() + + if self._metadata_server: + await self._ctx_router.stop_server_monitoring() + await self._gen_router.stop_server_monitoring() + + async def _wait_for_all_servers_ready(self) -> None: + async def check_servers_ready(): + elapsed_time = 0 + interval = self._health_check_interval_secs + while elapsed_time < self._server_start_timeout_secs: + _, unready_ctx_servers = await self._ctx_client.check_ready() + _, unready_gen_servers = await self._gen_client.check_ready() + if len(unready_ctx_servers) == 0 and len(unready_gen_servers) == 0: + logger.info("All servers are ready") + return + logger.info( + f"Waiting for servers, context: {unready_ctx_servers}, generation: {unready_gen_servers}" + ) + await asyncio.sleep(interval) + elapsed_time += interval + + try: + await asyncio.wait_for(check_servers_ready(), 
+                                   timeout=self._server_start_timeout_secs)
+        except asyncio.TimeoutError:
+            raise TimeoutError("Timeout waiting for context and generation servers to be ready")
+
+    async def _on_worker_event(self, worker_info: WorkerInfo, event_type: WatchEventType):
+        router_map = {ServerRole.CONTEXT: self._ctx_router, ServerRole.GENERATION: self._gen_router}
+        worker_addr = f"{worker_info.host}:{worker_info.port}"
+        try:
+            router = router_map[worker_info.role]
+            if event_type == WatchEventType.SET:
+                await router.add_server(worker_addr)
+            elif event_type == WatchEventType.DELETE:
+                await router.remove_server(worker_addr)
+            logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}, {worker_addr}")
+        except KeyError:
+            logger.error(
+                f"Unknown worker role: {worker_info.role}, Worker {worker_info.worker_id} event: {event_type.name}"
+            )
+
+    async def _verify_ctx_response(
+        self, ctx_response: UCompletionResponse
+    ) -> Optional[UCompletionResponse]:
+        if ctx_response:
+            if len(ctx_response.choices) != 1:
+                raise ValueError(
+                    f"Context server returned {len(ctx_response.choices)} choices, expecting 1."
+                )
+            if ctx_response.choices[0].disaggregated_params is None:
+                raise ValueError("Context server did not return disaggregated params")
+            if ctx_response.choices[0].disaggregated_params.ctx_request_id is None:
+                raise ValueError("Invalid disaggregated params in context phase response.")
+        return ctx_response
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index af8111d1f07..283f4c74e94 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -955,3 +955,7 @@ def to_llm_disaggregated_params(
         opaque_state=decode_opaque_state(
             disaggregated_params.encoded_opaque_state),
         draft_tokens=disaggregated_params.draft_tokens)
+
+
+UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
+UCompletionResponse = Union[CompletionResponse, ChatCompletionResponse]
diff --git a/tensorrt_llm/serve/openai_service.py b/tensorrt_llm/serve/openai_service.py
new file mode 100644
index 00000000000..e5e6e2047f6
--- /dev/null
+++ b/tensorrt_llm/serve/openai_service.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# yapf: disable
+from abc import ABC, abstractmethod
+
+from tensorrt_llm.serve.openai_protocol import ChatCompletionRequest, CompletionRequest
+from tensorrt_llm.serve.responses_utils import UCompletionResponseOrGenerator
+
+# yapf: enable
+
+
+class OpenAIService(ABC):
+    @abstractmethod
+    async def openai_completion(self, request: CompletionRequest) -> UCompletionResponseOrGenerator:
+        """Return either a completion response or an async completion response generator.
+
+        When the request is streaming, the generator will be used to stream the response.
+        When it is not streaming, the response will be returned directly.
+        """
+        ...
+ + @abstractmethod + async def openai_chat_completion( + self, request: ChatCompletionRequest + ) -> UCompletionResponseOrGenerator: + """Similar to openai_completion, but for chat completion protocol.""" + ... + + @abstractmethod + async def is_ready(self) -> bool: + """Check if the service is ready to accept requests.""" + ... + + @abstractmethod + async def setup(self) -> None: + """Setup the service.""" + ... + + @abstractmethod + async def teardown(self) -> None: + """Teardown the service.""" + ... diff --git a/tensorrt_llm/serve/perf_metrics.py b/tensorrt_llm/serve/perf_metrics.py new file mode 100644 index 00000000000..60b65179eaa --- /dev/null +++ b/tensorrt_llm/serve/perf_metrics.py @@ -0,0 +1,209 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Union + +from tensorrt_llm.llmapi.disagg_utils import ServerRole + +COUNTER_METRICS = [ + ("total_requests", "Total number of requests"), + ("error_requests", "Total number of error requests"), + ("retry_requests", "Total number of retry requests"), + ("completed_requests", "Total number of completed requests"), +] +# fmt: off +LONG_TIME_BUCKETS = [ + 0.1, 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, + 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, +] +SHORT_TIME_BUCKETS = [ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, + 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, 2560.0, +] +# fmt: on +HISTOGRAM_METRICS = [ + ( + "first_token_latency_seconds", + "Histogram of latency from first token to completion in seconds", + SHORT_TIME_BUCKETS, + ), + ( + "complete_latency_seconds", + "Histogram of latency from request arrival to last token in seconds", + LONG_TIME_BUCKETS, + ), + ( + "per_token_latency_seconds", + "Histogram of latency from request arrival to completion in seconds", + SHORT_TIME_BUCKETS, + ), +] + +MetricsTypeLiteral = Literal["counter", "histogram"] + + +@dataclass +class MetricsDefinition: + name: str + description: str + type: MetricsTypeLiteral + buckets: Optional[List[float]] = None + + +METRICS_DEFINITIONS = [ + MetricsDefinition("total_requests", "Total number of requests", "counter"), + MetricsDefinition("error_requests", "Total number of error requests", "counter"), + MetricsDefinition("retry_requests", "Total number of retry requests", "counter"), + MetricsDefinition("completed_requests", "Total number of completed requests", "counter"), + MetricsDefinition( + "first_token_latency_seconds", + "Histogram of latency from first token to completion in seconds", + "histogram", + SHORT_TIME_BUCKETS, + ), + MetricsDefinition( + "complete_latency_seconds", + "Histogram of latency from request arrival to last token in seconds", + "histogram", + LONG_TIME_BUCKETS, + ), + MetricsDefinition( + "per_token_latency_seconds", + "Histogram of latency from request arrival to completion in 
seconds", + "histogram", + SHORT_TIME_BUCKETS, + ), +] + +ROLE_TO_CLIENT_TYPE = { + ServerRole.CONTEXT: "ctx", + ServerRole.GENERATION: "gen", + ServerRole.MM_ENCODER: "mme", +} + + +class ClientMetricsCollector: + def __init__(self, role: ServerRole): + self._role = role + # import lazily to avoid breaking `set_prometheus_multiproc_dir` + from prometheus_client import Counter, Histogram + + def instance_metric(definition: MetricsDefinition) -> Union[Counter | Histogram]: + name = f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}" + if definition.type == "counter": + return Counter(name, definition.description) + elif definition.type == "histogram": + return Histogram(name, definition.description, buckets=definition.buckets) + else: + raise ValueError(f"Invalid metric type: {definition.type}") + + self._metrics = { + definition.name: instance_metric(definition) for definition in METRICS_DEFINITIONS + } + + def __getattr__( + self, key: str + ): # no return type hint to not import prometheus_client at module level + return self._metrics[key] + + +class DisaggPerfMetricsCollector: + def __init__(self, max_requests: int): + self._max_requests = max_requests + self._request_meteics = deque(maxlen=max_requests) + self._server_metrics = defaultdict(dict) + self._lock = asyncio.Lock() + self._clients = [] + + def add_client(self, client): + self._clients.append(client) + + async def add_per_request_metrics( + self, + ctx_server: str, + gen_server: str, + ctx_request_id: int, + server_arrival_time: float, + server_first_token_time: float, + ): + async with self._lock: + self._request_meteics.append( + ( + ctx_server, + gen_server, + ctx_request_id, + server_arrival_time, + server_first_token_time, + ) + ) + + async def get_perf_metrics(self) -> List[Dict[str, Any]]: + perf_metrics = {} + for client in self._clients: + metrics_dict = await client.collect_metrics() + perf_metrics.update(metrics_dict) + + return_metrics = [] + async with self._lock: + for server, metrics_data in perf_metrics.items(): + server_metrics = self._server_metrics[server] + # avoid metrics map inflation by limiting the number of requests to add + available_req_num = min( + max(0, self._max_requests - len(server_metrics)), len(metrics_data) + ) + req_metrics_map = { + req_metrics["ctx_request_id"]: req_metrics + for req_metrics in metrics_data[:available_req_num] + if "ctx_request_id" in req_metrics + } + server_metrics.update(req_metrics_map) + + remain_keys = [] + for ( + ctx_server, + gen_server, + ctx_request_id, + server_arrival_time, + server_first_token_time, + ) in self._request_meteics: + gen_perf_metrics = self._server_metrics[gen_server].pop(ctx_request_id, None) + if gen_perf_metrics is None: + # generation not finished + remain_keys.append( + ( + ctx_server, + gen_server, + ctx_request_id, + server_arrival_time, + server_first_token_time, + ) + ) + continue + ctx_perf_metrics = self._server_metrics[ctx_server].pop(ctx_request_id, None) + # TODO: strip the keys for less repeating and use table style response + return_metrics.append( + { + "ctx_server": ctx_server, + "gen_server": gen_server, + "disagg_server_arrival_time": server_arrival_time, + "disagg_server_first_token_time": server_first_token_time, + "ctx_perf_metrics": ctx_perf_metrics, + "gen_perf_metrics": gen_perf_metrics, + } + ) + self._request_meteics = deque(remain_keys, maxlen=self._max_requests) + return return_metrics diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py index ab8fdae47b5..18e26f735db 100644 --- 
a/tensorrt_llm/serve/responses_utils.py +++ b/tensorrt_llm/serve/responses_utils.py @@ -6,11 +6,12 @@ import os import time import uuid +# yapf: disable +from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from copy import copy -from typing import Literal, Optional, OrderedDict, Union +from typing import Any, Literal, Optional, OrderedDict, Union -# yapf: disable from openai.types.responses import (ResponseCompletedEvent, ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, @@ -25,7 +26,6 @@ ResponseReasoningTextDoneEvent, ResponseTextDeltaEvent, ResponseTextDoneEvent) -# yapf: enable from openai.types.responses.response_function_web_search import ( ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch) from openai.types.responses.response_reasoning_item import Content @@ -42,10 +42,14 @@ from tensorrt_llm.serve.openai_protocol import (OpenAIBaseModel, ResponseInputOutputItem, ResponsesRequest, - ResponsesResponse) + ResponsesResponse, + UCompletionRequest, + UCompletionResponse) from .harmony_adapter import HarmonyAdapter +# yapf: enable + REASONING_EFFORT = { "high": ReasoningEffort.HIGH, "medium": ReasoningEffort.MEDIUM, @@ -883,3 +887,39 @@ async def __call__(self, scope, receive, send): # Pass through the original receive/send - no wrapping! await self.app(scope, receive, send) + + +class ResponseHooks(ABC): + """ + Hooks for response processing and (disagg) service perf observability. + """ + + @abstractmethod + def on_req_begin(self, request: UCompletionRequest): + pass + + @abstractmethod + def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse): + pass + + @abstractmethod + def on_first_token(self, + gen_server: str, + request: UCompletionRequest, + response: UCompletionResponse = None): + pass + + @abstractmethod + def on_resp_done(self, + gen_server: str, + request: UCompletionRequest, + response: UCompletionResponse = None): + pass + + +async def done_generator() -> AsyncGenerator[bytes, None]: + yield "data: [DONE]\n\n".encode('utf-8') + + +UCompletionResponseOrGenerator = Union[UCompletionResponse, + AsyncGenerator[Any, None]] diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 720da1acbdc..253b39a00cf 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -402,6 +402,14 @@ def run_client_tests(example_dir, assert not_expected_string not in content, f"Unexpected string '{not_expected_string}' found in {output_file}" +# TODO: add test for disaggregated server prometheus metrics +def fetch_prometheus_metrics(server_url: str): + import requests + response = requests.get(f"{server_url}/prometheus/metrics", timeout=10) + assert response.status_code == 200 + return response.text + + def run_disaggregated_test(example_dir, test_desc, num_iters=5, diff --git a/tests/integration/defs/disaggregated/test_workers.py b/tests/integration/defs/disaggregated/test_workers.py index 786d34d0e6c..c4fb51f63db 100644 --- a/tests/integration/defs/disaggregated/test_workers.py +++ b/tests/integration/defs/disaggregated/test_workers.py @@ -14,7 +14,7 @@ from transformers import AutoTokenizer from tensorrt_llm import logger -from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer +from tensorrt_llm.serve.openai_client import OpenAIHttpClient from tensorrt_llm.serve.openai_protocol import (CompletionRequest, DisaggregatedParams) from 
tensorrt_llm.serve.router import (KvCacheAwareRouter, @@ -66,6 +66,34 @@ def run_disaggregated_workers( DEFAULT_TIMEOUT_REQUEST = 180 +async def wait_until_all_servers_ready( + session: aiohttp.ClientSession, + servers: List[str], + server_start_timeout_secs: int = 180, +) -> None: + + async def check_all_servers_ready(): + elapsed_time = 0 + interval = 3 + while elapsed_time < server_start_timeout_secs: + _, unready_servers = await OpenAIHttpClient.check_ready_for_servers( + session, servers) + if len(unready_servers) == 0: + return + await asyncio.sleep(interval) + elapsed_time += interval + logger.info( + f"[{elapsed_time}] Waiting for servers, {unready_servers}...") + + try: + await asyncio.wait_for(check_all_servers_ready(), + timeout=server_start_timeout_secs) + except asyncio.TimeoutError: + raise TimeoutError( + f"Timeout waiting for all servers to be ready in {server_start_timeout_secs} seconds" + ) + + class BasicWorkerTester: def __init__(self, @@ -82,9 +110,9 @@ async def new_session(self): session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(force_close=True), timeout=aiohttp.ClientTimeout(total=self.req_timeout_secs)) - await OpenAIDisaggServer.wait_for_all_servers_ready( - session, self.ctx_servers, self.gen_servers, - self.server_start_timeout_secs) + await wait_until_all_servers_ready(session, + self.ctx_servers + self.gen_servers, + self.server_start_timeout_secs) return session async def send_request(self, session: aiohttp.ClientSession, url: str, diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 5fc56bd938f..4aa41e577a4 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -23,6 +23,7 @@ l0_a10: # test list either). - unittest/_torch/models/checkpoints/hf/test_weight_loader.py - unittest/others/test_time_breakdown.py + - unittest/disaggregated/test_disagg_openai_client.py - unittest/disaggregated/test_disagg_utils.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py diff --git a/tests/unittest/disaggregated/test_disagg_openai_client.py b/tests/unittest/disaggregated/test_disagg_openai_client.py new file mode 100644 index 00000000000..698344da031 --- /dev/null +++ b/tests/unittest/disaggregated/test_disagg_openai_client.py @@ -0,0 +1,271 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
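+
+# Unit tests for OpenAIHttpClient: non-streaming and streaming requests, explicit
+# server selection, error handling, retry behavior, and request-type validation.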
+ +from unittest.mock import AsyncMock, Mock, patch + +import aiohttp +import pytest + +from tensorrt_llm.llmapi.disagg_utils import ServerRole +from tensorrt_llm.serve.openai_client import OpenAIHttpClient +from tensorrt_llm.serve.openai_protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DisaggregatedParams, + UsageInfo, +) +from tensorrt_llm.serve.router import Router + + +@pytest.fixture +def mock_router(): + """Create a mock router.""" + router = AsyncMock(spec=Router) + router.servers = ["localhost:8000", "localhost:8001"] + router.get_next_server = AsyncMock(return_value=("localhost:8000", None)) + router.finish_request = AsyncMock() + return router + + +@pytest.fixture +def mock_session(): + """Create a mock aiohttp session.""" + return AsyncMock(spec=aiohttp.ClientSession) + + +@pytest.fixture +def openai_client(mock_router, mock_session): + """Create an OpenAIHttpClient instance.""" + # uninitialize the prometheus metrics collector or it will raise a duplicate metric error + from prometheus_client.registry import REGISTRY + + REGISTRY._names_to_collectors = {} + REGISTRY._collector_to_names = {} + return OpenAIHttpClient( + router=mock_router, + role=ServerRole.CONTEXT, + timeout_secs=180, + max_retries=2, + retry_interval_sec=1, + session=mock_session, + ) + + +@pytest.fixture +def completion_request(): + """Create a sample non-streaming CompletionRequest.""" + return CompletionRequest( + model="test-model", + prompt="Hello, world!", + stream=False, + disaggregated_params=DisaggregatedParams( + request_type="generation_only", first_gen_tokens=[123], ctx_request_id=123 + ), + ) + + +@pytest.fixture +def streaming_completion_request(): + """Create a sample streaming CompletionRequest.""" + return CompletionRequest( + model="test-model", + prompt="Hello, world!", + stream=True, + disaggregated_params=DisaggregatedParams( + request_type="generation_only", first_gen_tokens=[456], ctx_request_id=456 + ), + ) + + +class TestOpenAIHttpClient: + """Test OpenAIHttpClient main functionality.""" + + def dummy_response(self): + return CompletionResponse( + id="test-123", + object="text_completion", + created=1234567890, + model="test-model", + usage=UsageInfo(prompt_tokens=10, completion_tokens=10), + choices=[CompletionResponseChoice(index=0, text="Hello!")], + ) + + def test_initialization(self, mock_router, mock_session): + """Test client initialization.""" + client = OpenAIHttpClient( + router=mock_router, + role=ServerRole.GENERATION, + timeout_secs=300, + max_retries=5, + session=mock_session, + ) + assert client._router == mock_router + assert client._role == ServerRole.GENERATION + assert client._session == mock_session + assert client._max_retries == 5 + + @pytest.mark.asyncio + async def test_non_streaming_completion_request( + self, openai_client, completion_request, mock_session, mock_router + ): + """Test non-streaming completion request end-to-end.""" + mock_response = self.dummy_response() + + # Mock HTTP response + mock_http_response = AsyncMock() + mock_http_response.status = 200 + mock_http_response.headers = {"Content-Type": "application/json"} + mock_http_response.json = AsyncMock(return_value=mock_response.model_dump()) + mock_http_response.raise_for_status = Mock() + mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response) + mock_http_response.__aexit__ = AsyncMock() + + mock_session.post.return_value = mock_http_response + + # Send request + response = await openai_client.send_request(completion_request) + + # 
Assertions + assert isinstance(response, CompletionResponse) + assert response.model == "test-model" + mock_session.post.assert_called_once() + mock_router.finish_request.assert_called_once_with(completion_request) + + @pytest.mark.asyncio + async def test_streaming_completion_request( + self, openai_client, streaming_completion_request, mock_session, mock_router + ): + """Test streaming completion request end-to-end.""" + # Mock HTTP streaming response + mock_http_response = AsyncMock() + mock_http_response.status = 200 + mock_http_response.headers = {"Content-Type": "text/event-stream"} + + dummy_data = [ + b'data: "Hello"\n\n', + b'data: "world"\n\n', + b'data: "!"\n\n', + ] + + async def mock_iter_any(): + for data in dummy_data: + yield data + + mock_http_response.content = AsyncMock() + mock_http_response.content.iter_any = mock_iter_any + mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response) + mock_http_response.__aexit__ = AsyncMock() + + mock_session.post.return_value = mock_http_response + + # Send streaming request + response_generator = await openai_client.send_request(streaming_completion_request) + + # Consume the generator + chunks = [] + async for chunk in response_generator: + chunks.append(chunk) + + # Assertions + assert len(chunks) == 3 + for i, chunk in enumerate(chunks): + assert chunk == dummy_data[i] + mock_session.post.assert_called_once() + mock_router.finish_request.assert_called_once_with(streaming_completion_request) + + @pytest.mark.asyncio + async def test_request_with_custom_server( + self, openai_client, completion_request, mock_session, mock_router + ): + """Test sending request to a specific server.""" + custom_server = "localhost:9000" + mock_response = self.dummy_response() + + mock_http_response = AsyncMock() + mock_http_response.headers = {"Content-Type": "application/json"} + mock_http_response.json = AsyncMock(return_value=mock_response.model_dump()) + mock_http_response.raise_for_status = Mock() + mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response) + mock_http_response.__aexit__ = AsyncMock() + + mock_session.post.return_value = mock_http_response + + await openai_client.send_request(completion_request, server=custom_server) + + # Verify custom server was used in URL + call_args = mock_session.post.call_args[0][0] + assert custom_server in call_args + # Router should not be called when server is specified + mock_router.get_next_server.assert_not_called() + + @pytest.mark.asyncio + async def test_request_error_handling( + self, openai_client, completion_request, mock_session, mock_router + ): + """Test error handling when request fails.""" + mock_session.post.side_effect = aiohttp.ClientError("Connection failed") + + with pytest.raises(aiohttp.ClientError): + await openai_client.send_request(completion_request) + + # Should finish request on error + mock_router.finish_request.assert_called_once_with(completion_request) + + @pytest.mark.asyncio + async def test_request_with_retry( + self, openai_client, completion_request, mock_session, mock_router + ): + """Test retry mechanism on transient failures.""" + mock_response = self.dummy_response() + + mock_http_response = AsyncMock() + mock_http_response.headers = {"Content-Type": "application/json"} + mock_http_response.json = AsyncMock(return_value=mock_response.model_dump()) + mock_http_response.raise_for_status = Mock() + mock_http_response.__aenter__ = AsyncMock(return_value=mock_http_response) + mock_http_response.__aexit__ = AsyncMock() + + # First 
attempt fails, second succeeds + mock_session.post.side_effect = [ + aiohttp.ClientError("Temporary failure"), + mock_http_response, + ] + + with patch("asyncio.sleep", new_callable=AsyncMock): + response = await openai_client.send_request(completion_request) + + assert isinstance(response, CompletionResponse) + assert mock_session.post.call_count == 2 # Initial + 1 retry + + @pytest.mark.asyncio + async def test_max_retries_exceeded( + self, openai_client, completion_request, mock_session, mock_router + ): + """Test that request fails after max retries.""" + mock_session.post.side_effect = aiohttp.ClientError("Connection failed") + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(aiohttp.ClientError): + await openai_client.send_request(completion_request) + + # Should try max_retries + 1 times + assert mock_session.post.call_count == openai_client._max_retries + 1 + mock_router.finish_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invalid_request_type(self, openai_client): + """Test handling of invalid request type.""" + with pytest.raises(ValueError, match="Invalid request type"): + await openai_client.send_request("invalid_request")
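
A minimal sketch of how the new ResponseHooks ABC (responses_utils.py) and DisaggPerfMetricsCollector (perf_metrics.py) introduced above could be wired together by whatever drives OpenAIClient.send_request. The PerRequestTimingHooks class, its attribute names, and the create_task scheduling are illustrative assumptions only; the patch's actual hook wiring is not shown in this excerpt.

import asyncio
from typing import Optional

from tensorrt_llm.serve.openai_protocol import UCompletionRequest, UCompletionResponse
from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
from tensorrt_llm.serve.responses_utils import (ResponseHooks,
                                                get_steady_clock_now_in_seconds)


class PerRequestTimingHooks(ResponseHooks):
    """Hypothetical hooks that time a single request and feed the perf metrics collector."""

    def __init__(self, collector: DisaggPerfMetricsCollector):
        self._collector = collector
        self._ctx_server = ""
        self._ctx_request_id: Optional[int] = None
        self._arrival_time = 0.0
        self._first_token_time = 0.0

    def on_req_begin(self, request: UCompletionRequest):
        # Request arrived at the disaggregated endpoint.
        self._arrival_time = get_steady_clock_now_in_seconds()

    def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
        # The context response carries disaggregated_params (enforced by _verify_ctx_response).
        self._ctx_server = ctx_server
        self._ctx_request_id = response.choices[0].disaggregated_params.ctx_request_id

    def on_first_token(self, gen_server: str, request: UCompletionRequest,
                       response: UCompletionResponse = None):
        self._first_token_time = get_steady_clock_now_in_seconds()

    def on_resp_done(self, gen_server: str, request: UCompletionRequest,
                     response: UCompletionResponse = None):
        # Hooks are synchronous; hand the async collector call to the running event loop.
        asyncio.get_running_loop().create_task(
            self._collector.add_per_request_metrics(
                self._ctx_server, gen_server, self._ctx_request_id,
                self._arrival_time, self._first_token_time))

An instance of such a class would be passed as the hooks argument to OpenAIClient.send_request, which is the callback surface the abstract client exposes in this patch.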