tensorrt_llm/llmapi/disagg_utils.py (3 additions, 3 deletions)
@@ -87,16 +87,16 @@ class MetadataServerConfig():
     refresh_interval: float = 10.0
 
 
-def get_ctx_gen_server_urls(
+def get_ctx_gen_server_addrs(
         server_configs: list[CtxGenServerConfig]
 ) -> tuple[list[str], list[str]]:
     ctx_server_urls = []
     gen_server_urls = []
     for cfg in server_configs:
         if cfg.type == "ctx":
-            ctx_server_urls.append(f"http://{cfg.hostname}:{cfg.port}")
+            ctx_server_urls.append(f"{cfg.hostname}:{cfg.port}")
         else:
-            gen_server_urls.append(f"http://{cfg.hostname}:{cfg.port}")
+            gen_server_urls.append(f"{cfg.hostname}:{cfg.port}")
 
     return ctx_server_urls, gen_server_urls

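With this rename, get_ctx_gen_server_addrs returns bare host:port strings instead of full URLs, so the scheme is attached at the point of use (as openai_client.py below does with f"http://{server}/{endpoint}"). A minimal sketch of the new calling convention; the hostnames, ports, and the CtxGenServerConfig keyword arguments are illustrative assumptions, not values from this PR:

# Sketch of the renamed helper under stated assumptions.
from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig,
                                              get_ctx_gen_server_addrs)

configs = [
    CtxGenServerConfig(type="ctx", hostname="node0", port=8001),
    CtxGenServerConfig(type="gen", hostname="node1", port=8002),
]
ctx_addrs, gen_addrs = get_ctx_gen_server_addrs(configs)
assert ctx_addrs == ["node0:8001"]  # bare host:port, no "http://" prefix anymore
assert gen_addrs == ["node1:8002"]
# callers now prepend the scheme themselves when building a URL:
url = f"http://{ctx_addrs[0]}/v1/completions"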
tensorrt_llm/serve/disagg_auto_scaling.py (39 additions, 2 deletions)
@@ -4,7 +4,7 @@
 import random
 import time
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, List, Tuple
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
 
 from tensorrt_llm.llmapi.disagg_utils import DisaggClusterConfig, ServerRole
 from tensorrt_llm.logger import logger
@@ -44,6 +44,7 @@ def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage):
         self._current_ctx_workers = {}  # worker_id -> WorkerInfo
         self._current_gen_workers = {}  # worker_id -> WorkerInfo
         self._watch_handle = None
+        self._watch_task = None
 
     def __del__(self):
         try:
@@ -92,7 +93,14 @@ def current_gen_worker_num(self) -> int:
     def worker_key_prefix(self) -> str:
         return get_worker_key_prefix(self._config.cluster_name)
 
-    async def watch_workers(self, get_existing_first: bool = True):
+    async def watch_workers(
+            self,
+            get_existing_first: bool = True,
+            on_event: Optional[Callable[[WorkerInfo, WatchEventType],
+                                        Awaitable[Any]]] = None):
+        if self._watch_handle:
+            logger.error("Watch handle is already initialized")
+            return []
         workers = []
         self._watch_handle = await self._cluster_storage.watch(
             self.worker_key_prefix)
@@ -109,12 +117,41 @@ async def watch_workers(self, get_existing_first: bool = True):
                 workers.append(self._parse_worker_info(event))
                 events.append(event)
             await self._watch_handle.add_events(events)
+
+        self._watch_handle = await self._cluster_storage.watch(
+            self.worker_key_prefix)
+
+        async def worker_event_loop():
+            logger.info(
+                f"Start watching worker events with {len(workers)} existing workers"
+            )
+            for worker_info in workers:
+                await on_event(worker_info, WatchEventType.SET)
+            while True:
+                try:
+                    worker_events = await self._watch_handle.drain()
+                    for event in worker_events:
+                        worker_info = self._parse_worker_info(event)
+                        await on_event(worker_info, event.event_type)
+                except asyncio.CancelledError:
+                    break
+                except Exception as e:
+                    logger.error(
+                        f"Error updating routers by worker events: {e}")
+                    await asyncio.sleep(1)
+            logger.info("Stop watching worker events")
+
+        if on_event:
+            self._watch_task = asyncio.create_task(worker_event_loop())
         return workers
 
     async def unwatch_workers(self) -> None:
         if self._watch_handle:
             await self._cluster_storage.unwatch(self.worker_key_prefix)
             self._watch_handle = None
+        if self._watch_task:
+            self._watch_task.cancel()
+            self._watch_task = None
 
     async def get_worker_events(
             self) -> List[Tuple[WorkerInfo, WatchEventType]]:
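The new on_event callback turns watch_workers into a push API: existing workers are replayed as SET events, after which a background task drains the watch handle until unwatch_workers cancels it. A minimal sketch of how a caller might wire this up; the watcher instance name, the WorkerInfo.worker_id field, and the WatchEventType import path are assumptions, since none of them are visible in this hunk:

# Usage sketch under stated assumptions; `watcher` is an instance of the
# cluster-watching class shown above.
import asyncio

from tensorrt_llm.serve.cluster_storage import WatchEventType  # assumed path

async def on_worker_event(worker_info, event_type):
    # invoked once per existing worker (as SET), then for every later change
    if event_type == WatchEventType.SET:
        print(f"worker up: {worker_info.worker_id}")
    else:
        print(f"worker down: {worker_info.worker_id}")

async def main(watcher):
    existing = await watcher.watch_workers(on_event=on_worker_event)
    print(f"{len(existing)} workers present at startup")
    await asyncio.sleep(60)          # events arrive via the background task
    await watcher.unwatch_workers()  # cancels the worker event loop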
tensorrt_llm/serve/openai_client.py (new file, 295 additions)
@@ -0,0 +1,295 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# yapf: disable
import asyncio
import traceback
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Type

import aiohttp

from tensorrt_llm.llmapi.disagg_utils import ServerRole
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    UCompletionRequest,
    UCompletionResponse,
)
from tensorrt_llm.serve.perf_metrics import ClientMetricsCollector
from tensorrt_llm.serve.responses_utils import (
    ResponseHooks,
    UCompletionResponseOrGenerator,
    get_steady_clock_now_in_seconds,
)
from tensorrt_llm.serve.router import Router

# yapf: enable


class OpenAIClient(ABC):
    async def send_request(
        self,
        request: UCompletionRequest,
        server: Optional[str] = None,
        hooks: Optional[ResponseHooks] = None,
    ) -> UCompletionResponseOrGenerator:
        if isinstance(request, CompletionRequest):
            return await self._send_request(
                "v1/completions", request, CompletionResponse, server, hooks
            )
        elif isinstance(request, ChatCompletionRequest):
            return await self._send_request(
                "v1/chat/completions", request, ChatCompletionResponse, server, hooks
            )
        else:
            raise ValueError(f"Invalid request type: {type(request)}")

    @abstractmethod
    async def _send_request(
        self,
        endpoint: str,
        request: UCompletionRequest,
        response_type: Type[UCompletionResponse],
        server: Optional[str] = None,
        hooks: Optional[ResponseHooks] = None,
    ) -> UCompletionResponseOrGenerator:
        """Send a request to the server and return the response and the body generator.

        The request is finished (in routers) when the generator is exhausted or there is an error.
        """
        ...

    @abstractmethod
    async def collect_metrics(self) -> Dict[str, Any]: ...

    @abstractmethod
    async def check_ready(self) -> Tuple[List[str], List[str]]:
        """Return the list of ready servers and the list of unready servers."""
        ...

    async def shutdown(self) -> None: ...

    @abstractmethod
    async def _finish_request(self, request: UCompletionRequest) -> None:
        """Finish the request in the router."""
        ...


class OpenAIHttpClient(OpenAIClient):
    def __init__(
        self,
        router: Router,
        role: ServerRole,
        timeout_secs: int = 180,
        max_retries: int = 1,
        retry_interval_sec: int = 1,
        session: Optional[aiohttp.ClientSession] = None,
    ):
        self._router = router
        self._role = role
        self._metrics_collector = ClientMetricsCollector(role)
        self._session = session or aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=False),
            timeout=aiohttp.ClientTimeout(total=timeout_secs),
        )
        self._max_retries = max_retries
        self._retry_interval_sec = retry_interval_sec

    async def _send_request(
        self,
        endpoint: str,
        request: UCompletionRequest,
        response_type: Type[UCompletionResponse],
        server: Optional[str] = None,
        hooks: Optional[ResponseHooks] = None,
    ) -> UCompletionResponseOrGenerator:
        if server is None:
            server, _ = await self._router.get_next_server(request)
        url = f"http://{server}/{endpoint}"
        logger.debug(
            f"Sending {self._role} request {request.disaggregated_params.ctx_request_id} to {url}"
        )
        try:
            self._metrics_collector.total_requests.inc()
            resp_generator = self._post_with_retry(server, url, request, hooks)
            if request.stream:
                # return the response generator; the request is not done yet
                return resp_generator
            else:
                # when not streaming, consume the generator to get the
                # response and return it directly
                response = None
                async for resp_json in resp_generator:
                    response = response_type(**resp_json)
                if hooks:
                    if self._role == ServerRole.CONTEXT:
                        hooks.on_ctx_resp(server, response)
                    else:
                        hooks.on_first_token(server, request)
                    hooks.on_resp_done(server, request, response)
                return response
        except Exception:
            self._metrics_collector.error_requests.inc()
            # finish the request upon error
            await self._finish_request(request)
            raise

    async def _post_with_retry(
        self,
        server: str,
        url: str,
        request: UCompletionRequest,
        hooks: Optional[ResponseHooks] = None,
    ) -> AsyncGenerator[Any, None]:
        json_data = request.model_dump(exclude_unset=True)
        is_stream = request.stream
        for attempt in range(self._max_retries + 1):
            try:
                start_time = get_steady_clock_now_in_seconds()
                async with self._session.post(url, json=json_data) as http_response:
                    content_type = http_response.headers.get("Content-Type", "")
                    if not is_stream and "text/event-stream" in content_type:
                        raise ValueError(
                            "Received an event-stream although request stream was False"
                        )
                    if is_stream:
                        # do NOT return the generator directly here or the response will go
                        # out of scope and get destroyed
                        async for line in self._response_generator(
                            request, http_response, start_time, server, hooks
                        ):
                            yield line
                        # don't finish the request here since the response generator is not done yet
                    else:
                        http_response.raise_for_status()
                        response_dict = await http_response.json()
                        # yield here since Python forbids returning a value in async generators
                        yield response_dict
                        # finish the request after the successful response
                        await self._finish_request(request)
                break  # break and skip retries if the whole response is processed without exception
            except (aiohttp.ClientError, OSError) as e:
                if attempt == self._max_retries:
                    logger.error(
                        f"Client error to {url}: {e} - last retry {attempt} of "
                        f"{self._max_retries} failed",
                        traceback.format_exc(),
                    )
                    raise
                logger.error(
                    f"{self._role} client error to {url}: {e} - retry {attempt} of {self._max_retries}",
                    traceback.format_exc(),
                )
                await asyncio.sleep(self._retry_interval_sec)
                self._metrics_collector.retry_requests.inc()
            except Exception as e:
                logger.error(
                    f"Unexpected error while processing {self._role} request to {url}: {e}"
                )
                raise

    async def _response_generator(
        self,
        request: UCompletionRequest,
        http_response: aiohttp.ClientResponse,
        start_time: float,
        server: str,
        hooks: Optional[ResponseHooks] = None,
    ) -> AsyncGenerator[Any, None]:
        assert request.stream, "Request is not streaming"
        assert "text/event-stream" in http_response.headers.get("Content-Type", ""), (
            "Response is not streaming"
        )
        try:
            last_token_time = start_time
            i = 0
            async for line in http_response.content.iter_any():
                now_time = get_steady_clock_now_in_seconds()
                if i == 0:
                    if hooks:
                        hooks.on_first_token(server, request)
                    self._metrics_collector.first_token_latency_seconds.observe(
                        now_time - last_token_time
                    )
                else:
                    self._metrics_collector.per_token_latency_seconds.observe(
                        now_time - last_token_time
                    )
                i += 1
                if line:
                    yield line
                    await asyncio.sleep(0)
                last_token_time = now_time

            if hooks:
                hooks.on_resp_done(server, request, None)
            self._metrics_collector.completed_requests.inc()
            self._metrics_collector.complete_latency_seconds.observe(
                get_steady_clock_now_in_seconds() - start_time
            )
        except aiohttp.ClientError as e:
            # a client error is expected when the response stream ends if the
            # connector was created with force_close=True
            logger.error(f"{self._role} client {server} error: {e}")
            self._metrics_collector.error_requests.inc()
            raise
        except Exception:
            self._metrics_collector.error_requests.inc()
            raise
        finally:
            # finish the request after the streaming response is done or an error is raised
            await self._finish_request(request)

    async def _finish_request(self, request: UCompletionRequest) -> None:
        await self._router.finish_request(request)

    async def collect_metrics(self) -> Dict[str, Any]:
        metrics = {}
        for server in self._router.servers:
            try:
                async with self._session.get(f"http://{server}/perf_metrics") as response:
                    metrics[server] = await response.json()
            except Exception:
                logger.error(f"Failed to collect metrics from {server}")
                continue
        return metrics

    async def shutdown(self) -> None:
        await self._session.close()

    async def check_ready(self) -> Tuple[List[str], List[str]]:
        return await OpenAIHttpClient.check_ready_for_servers(self._session, self._router.servers)

    @staticmethod
    async def check_ready_for_servers(
        session: aiohttp.ClientSession, servers: List[str]
    ) -> Tuple[List[str], List[str]]:
        async def check_server_ready(server: str) -> bool:
            try:
                url = (
                    f"{server}/health"
                    if server.startswith("http://")
                    else f"http://{server}/health"
                )
                async with session.get(url) as response:
                    return response.status == 200
            except Exception:
                return False

        servers_ready = await asyncio.gather(*[check_server_ready(server) for server in servers])
        return [server for server, ready in zip(servers, servers_ready) if ready], [
            server for server, ready in zip(servers, servers_ready) if not ready
        ]
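Taken together, OpenAIHttpClient is the per-role HTTP leg of the disaggregated server: the router hands out a scheme-less host:port address (matching the disagg_utils change above), _send_request builds the URL, and send_request returns either a parsed response object or, for streaming requests, an async generator whose exhaustion finishes the request in the router. A minimal non-streaming usage sketch; the Router construction, the CompletionRequest field values, and the model name are assumptions, and note that the debug log in _send_request expects request.disaggregated_params to be populated:

# Usage sketch under stated assumptions; not a verbatim example from this PR.
import asyncio

from tensorrt_llm.llmapi.disagg_utils import ServerRole
from tensorrt_llm.serve.openai_client import OpenAIHttpClient
from tensorrt_llm.serve.openai_protocol import CompletionRequest

async def main(router):  # router: a configured tensorrt_llm.serve.router.Router
    client = OpenAIHttpClient(router, ServerRole.CONTEXT)
    ready, unready = await client.check_ready()
    print(f"ready: {ready}, unready: {unready}")

    # stream=False, so send_request consumes the internal generator and
    # returns a parsed CompletionResponse instead of an async generator
    request = CompletionRequest(model="my-model", prompt="Hello", stream=False)
    response = await client.send_request(request)
    print(response.choices[0].text)

    await client.shutdown()  # closes the shared aiohttp session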