Commit b419e60

fix failed CI tests
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 57451e6 commit b419e60

File tree: 3 files changed, +52 -15 lines changed

tensorrt_llm/serve/openai_client.py
tensorrt_llm/serve/openai_disagg_service.py
tests/integration/defs/disaggregated/test_workers.py

tensorrt_llm/serve/openai_client.py

Lines changed: 18 additions & 9 deletions
@@ -29,7 +29,7 @@
     UCompletionRequest,
     UCompletionResponse,
 )
-from tensorrt_llm.serve.perf_metrics import ClientMetricsCollector, DisaggPerfMetricsCollector
+from tensorrt_llm.serve.perf_metrics import ClientMetricsCollector
 from tensorrt_llm.serve.responses_utils import (
     ResponseHooks,
     UCompletionResponseOrGenerator,
@@ -93,13 +93,13 @@ def __init__(
         client_type: str,
         timeout_secs: int = 180,
         max_retries: int = 1,
-        perf_metrics_collector: DisaggPerfMetricsCollector = None,
+        session: Optional[aiohttp.ClientSession] = None,
     ):
         assert client_type in ["ctx", "gen"]
         self._router = router
         self._client_type = client_type
         self._metrics_collector = ClientMetricsCollector(client_type)
-        self._session = aiohttp.ClientSession(
+        self._session = session or aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=False),
             timeout=aiohttp.ClientTimeout(total=timeout_secs),
         )
@@ -263,16 +263,25 @@ async def shutdown(self) -> None:
         await self._session.close()
 
     async def check_ready(self) -> Tuple[List[str], List[str]]:
+        return await OpenAIHttpClient.check_ready_for_servers(self._session, self._router.servers)
+
+    @staticmethod
+    async def check_ready_for_servers(
+        session: aiohttp.ClientSession, servers: List[str]
+    ) -> Tuple[List[str], List[str]]:
         async def check_server_ready(server: str) -> bool:
             try:
-                async with self._session.get(f"http://{server}/health") as response:
+                url = (
+                    f"{server}/health"
+                    if server.startswith("http://")
+                    else f"http://{server}/health"
+                )
+                async with session.get(url) as response:
                     return response.status == 200
             except Exception:
                 return False
 
-        servers_ready = await asyncio.gather(
-            *[check_server_ready(server) for server in self._router.servers]
-        )
-        return [server for server, ready in zip(self._router.servers, servers_ready) if ready], [
-            server for server, ready in zip(self._router.servers, servers_ready) if not ready
+        servers_ready = await asyncio.gather(*[check_server_ready(server) for server in servers])
+        return [server for server, ready in zip(servers, servers_ready) if ready], [
+            server for server, ready in zip(servers, servers_ready) if not ready
         ]

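For context, a minimal usage sketch (not part of this commit) of the readiness probe that the diff above turns into a static method. The worker addresses are illustrative; per the diff, entries may be given with or without an "http://" prefix, and each server's /health endpoint is probed.

import asyncio

import aiohttp

from tensorrt_llm.serve.openai_client import OpenAIHttpClient


async def main() -> None:
    # Hypothetical worker addresses; mixing bare host:port and prefixed URLs is
    # fine because check_ready_for_servers normalizes the URL before calling /health.
    servers = ["localhost:8001", "http://localhost:8002"]
    async with aiohttp.ClientSession() as session:
        ready, unready = await OpenAIHttpClient.check_ready_for_servers(session, servers)
        print(f"ready={ready} unready={unready}")


asyncio.run(main())

Making the check static and session-based (and letting the constructor accept an injected session) is what allows the integration test further below to reuse it without constructing a full client.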
tensorrt_llm/serve/openai_disagg_service.py

Lines changed: 2 additions & 2 deletions
@@ -226,7 +226,7 @@ async def setup(self) -> None:
         await self._gen_router.start_server_monitoring(
             self._metadata_config.refresh_interval
         )
-        await self._wait_for_servers_ready()
+        await self._wait_for_all_servers_ready()
 
     async def teardown(self) -> None:
         await self._ctx_client.shutdown()
@@ -239,7 +239,7 @@ async def teardown(self) -> None:
         await self._ctx_router.stop_server_monitoring()
         await self._gen_router.stop_server_monitoring()
 
-    async def _wait_for_servers_ready(self) -> None:
+    async def _wait_for_all_servers_ready(self) -> None:
         async def check_servers_ready():
             elapsed_time = 0
             interval = 3

tests/integration/defs/disaggregated/test_workers.py

Lines changed: 32 additions & 4 deletions
@@ -14,7 +14,7 @@
 from transformers import AutoTokenizer
 
 from tensorrt_llm import logger
-from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer
+from tensorrt_llm.serve.openai_client import OpenAIHttpClient
 from tensorrt_llm.serve.openai_protocol import (CompletionRequest,
                                                 DisaggregatedParams)
 from tensorrt_llm.serve.router import (KvCacheAwareRouter,
@@ -66,6 +66,34 @@ def run_disaggregated_workers(
 DEFAULT_TIMEOUT_REQUEST = 180
 
 
+async def wait_until_all_servers_ready(
+        session: aiohttp.ClientSession,
+        servers: List[str],
+        server_start_timeout_secs: int = 180,
+) -> None:
+
+    async def check_all_servers_ready():
+        elapsed_time = 0
+        interval = 3
+        while elapsed_time < server_start_timeout_secs:
+            _, unready_servers = await OpenAIHttpClient.check_ready_for_servers(
+                session, servers)
+            if len(unready_servers) == 0:
+                return
+            await asyncio.sleep(interval)
+            elapsed_time += interval
+            logger.info(
+                f"[{elapsed_time}] Waiting for servers, {unready_servers}...")
+
+    try:
+        await asyncio.wait_for(check_all_servers_ready(),
+                               timeout=server_start_timeout_secs)
+    except asyncio.TimeoutError:
+        raise TimeoutError(
+            f"Timeout waiting for all servers to be ready in {server_start_timeout_secs} seconds"
+        )
+
+
 class BasicWorkerTester:
 
     def __init__(self,
@@ -82,9 +110,9 @@ async def new_session(self):
         session = aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(force_close=True),
             timeout=aiohttp.ClientTimeout(total=self.req_timeout_secs))
-        await OpenAIDisaggServer.wait_for_all_servers_ready(
-            session, self.ctx_servers, self.gen_servers,
-            self.server_start_timeout_secs)
+        await wait_until_all_servers_ready(session,
+                                           self.ctx_servers + self.gen_servers,
+                                           self.server_start_timeout_secs)
         return session
 
     async def send_request(self, session: aiohttp.ClientSession, url: str,

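As a usage note (not part of the commit), a hypothetical async test in the style of test_workers.py could drive the new helper directly. The worker addresses and the 60-second timeout below are illustrative, and the module's existing asyncio/aiohttp imports plus wait_until_all_servers_ready are assumed to be in scope.

async def example_wait_for_workers() -> None:
    servers = ["localhost:8001", "localhost:8002"]  # illustrative ctx/gen workers
    session = aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(force_close=True),
        timeout=aiohttp.ClientTimeout(total=180))
    try:
        # Probes every server's /health roughly every 3 seconds and raises
        # TimeoutError if any server is still unready after 60 seconds.
        await wait_until_all_servers_ready(session, servers,
                                           server_start_timeout_secs=60)
    finally:
        await session.close()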