Commit 24167d0

[TRTLLM-8431][doc] update public doc and example, add etcd auto-scaling tests (#8602)
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 227c288 commit 24167d0

File tree

8 files changed

+198
-43
lines changed

examples/disaggregated/README.md

Lines changed: 34 additions & 2 deletions
@@ -204,7 +204,39 @@ srun -A <account> -p <partition> -t <time> \
 Additionally, we offer a fully executable script—please refer to [Disaggregated SLURM Scripts](./slurm/simple_example/).
 
 
-## Dynamic scaling (Prototype)
+## Dynamic scaling
+
+### Service discovery method
+
+The disaggregated server also supports dynamic service discovery and auto-scaling of context/generation servers. To enable it, add a `disagg_cluster` section to the configurations of both the context/generation servers and the disaggregated server. In this mode, each context/generation server must be started with an extra command-line flag `--server-role=[context|generation]`, and the `context_servers`/`generation_servers` sections must be removed from the disaggregated server's configuration. You can omit the `disagg_cluster` section from a context/generation server's config by passing only `--disagg_cluster_uri=<disagg_cluster_uri>` on the command line (the disaggregated server's config must still contain the section). Omitted fields fall back to the defaults shown below.
+
+```yaml
+disagg_cluster:
+  cluster_uri: <your_cluster_uri>
+  cluster_name: ""
+  minimal_instances:
+    context_servers: 1
+    generation_servers: 1
+  heartbeat_interval_sec: 5
+  inactive_interval_sec: 10
+```
+- `cluster_uri`: the HTTP address of the disaggregated server, e.g. `http://<your-disagg-server-host>:<your-disagg-server-port>`, or the address of a pre-configured etcd server, e.g. `etcd://<your-etcd-host>:2379`.
+- `cluster_name`: an optional namespace used to isolate multiple disagg-clusters in etcd.
+- `minimal_instances`: the equivalent of `num_instances` in the auto-scaling concept; the disaggregated server rejects requests while the number of active context/generation servers is below the corresponding threshold.
+- `heartbeat_interval_sec`: how often context/generation servers send heartbeats to the disaggregated server.
+- `inactive_interval_sec`: a server is marked inactive if no heartbeat is received within this interval (set it higher than the heartbeat interval).
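The interplay of `heartbeat_interval_sec` and `inactive_interval_sec` can be sketched in a few lines of Python. This is a hypothetical illustration of the liveness rule described above, not the actual TRT-LLM implementation; the names (`last_heartbeat`, `on_heartbeat`, `active_workers`) are made up for the sketch.

```python
HEARTBEAT_INTERVAL_SEC = 5   # how often workers send heartbeats (default above)
INACTIVE_INTERVAL_SEC = 10   # no heartbeat for this long => worker is inactive

# hypothetical map of worker id -> timestamp of the last received heartbeat
last_heartbeat = {}

def on_heartbeat(worker_id: str, now: float) -> None:
    # record the most recent heartbeat time for this worker
    last_heartbeat[worker_id] = now

def active_workers(now: float) -> list:
    # a worker counts as active while its last heartbeat is recent enough
    return [w for w, t in last_heartbeat.items()
            if now - t <= INACTIVE_INTERVAL_SEC]

on_heartbeat("ctx-0", now=100.0)
on_heartbeat("gen-0", now=104.0)
# at t=112, ctx-0 has missed two heartbeat windows and is considered inactive
print(active_workers(now=112.0))  # -> ['gen-0']
```

Setting `inactive_interval_sec` to roughly twice the heartbeat interval (as the defaults do) tolerates a single dropped heartbeat without flapping.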
+
+Note that the disaggregated server and all context/generation servers should use the same `disagg_cluster` configuration values; otherwise the disaggregated server may fail to keep the other servers alive or to detect their inactivity properly. If the `disagg_cluster` section is specified, the `context_servers`/`generation_servers` sections must not appear in the disaggregated server's configuration.
+
+Additionally, we offer a fully executable script—please refer to [Disaggregated SLURM Scripts](./slurm/service_discovery_example/).
+
+#### Dynamically adding servers
+
+To add servers dynamically, start more context/generation workers with the same `disagg_cluster` configuration; the disaggregated server will discover the new servers and dispatch requests to them automatically. If a context/generation server becomes inactive, the disaggregated server detects this and stops routing requests to it.
+
+
+### Metadata server method (Prototype)

 Currently, trtllm supports dynamic addition and removal of servers by leveraging etcd. To enable this feature, start the context and generation servers with the additional flags ```--metadata_server_config_file``` and ```--server_role```.
 Before launching the context and generation servers, you must first start the etcd server. By default, the etcd server listens for client requests at ```localhost:2379```.
@@ -240,7 +272,7 @@ refersh_interval: 10.0
 
 The ```hostname``` and ```port``` must match those used when starting the etcd server. The ```health_check_timeout``` parameter specifies how long to wait without a healthy response before a server is considered dead; by default, trtllm performs two checks before marking a server as dead. The ```refresh_interval``` parameter determines how often the latest server list is fetched from the etcd server.
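Taken together, a metadata server config file of this shape is implied by the parameters above. The values here are illustrative, not taken from the commit; only the field names come from the surrounding text.

```yaml
hostname: localhost        # must match the etcd server's host
port: 2379                 # must match the etcd server's client port
health_check_timeout: 5.0  # seconds without a healthy response before a server is considered dead
refresh_interval: 10.0     # how often to fetch the latest server list from etcd, in seconds
```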
 
-### Dynamically adding servers
+#### Dynamically adding servers
 
 Users can add servers by directly launching them with trtllm-serve. For example, you can start an additional generation server as follows:
 
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
```bash
#!/bin/bash
#SBATCH --partition=${partition}
#SBATCH --account=${account}
#SBATCH --job-name=${job_name}
#SBATCH --time=02:00:00

container_image="${container_image:-}"
mount_paths="${mount_paths:-}"
work_path="${work_path:-}"
enable_etcd="${enable_etcd:-0}"
disagg_port="8000"
ctx_port="8001"
gen_port="8002"

# use the first node as the disaggregated server node
disagg_server_node=$(head -n 1 <(scontrol show hostnames $SLURM_JOB_NODELIST))

if [[ "$enable_etcd" == "1" ]]; then
    # optionally launch an etcd server; the container image must have etcd installed
    disagg_cluster_uri="etcd://${disagg_server_node}:2379"
    srun --container-image=${container_image} \
         --container-mounts=${mount_paths} \
         -w $disagg_server_node -N 1 --ntasks-per-node=1 \
         --mpi=pmix \
         bash -c "etcd" &
    sleep 5 # wait for etcd to start
else
    # or use the disaggregated server's HTTP address as the built-in service discovery server
    disagg_cluster_uri="http://${disagg_server_node}:${disagg_port}"
fi

cat >${work_path}/disagg_config.yaml << EOL
hostname: localhost
port: ${disagg_port}
backend: pytorch
disagg_cluster:
  cluster_uri: ${disagg_cluster_uri}
  cluster_name: example_cluster
EOL

cat >${work_path}/ctx_extra-llm-api-config.yaml << EOL
disable_overlap_scheduler: True
cache_transceiver_config:
  backend: UCX
  max_tokens_in_buffer: 2048
EOL

cat >${work_path}/gen_extra-llm-api-config.yaml << EOL
cache_transceiver_config:
  backend: UCX
  max_tokens_in_buffer: 2048
EOL

# Launch a proxy without any context/generation servers.
srun --container-image=${container_image} \
     --container-mounts=${mount_paths} \
     -w $disagg_server_node -N 1 --ntasks-per-node=1 \
     --mpi=pmix \
     bash -c "trtllm-llmapi-launch trtllm-serve disaggregated -c ${work_path}/disagg_config.yaml" &

# Launch a context server with `tp_size=8` using two 4-GPU nodes; it registers itself through disagg_cluster_uri
srun --container-image=${container_image} \
     --container-mounts=${mount_paths} \
     -N 2 --ntasks-per-node=4 \
     --mpi=pmix \
     bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 8 --host 0.0.0.0 --port ${ctx_port} --extra_llm_api_options ${work_path}/ctx_extra-llm-api-config.yaml --disagg_cluster_uri ${disagg_cluster_uri} --server-role context" &

# Launch a generation server with `tp_size=4` using one 4-GPU node.
srun --container-image=${container_image} \
     --container-mounts=${mount_paths} \
     -N 1 --ntasks-per-node=4 \
     --mpi=pmix \
     bash -c "trtllm-llmapi-launch trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tp_size 4 --host 0.0.0.0 --port ${gen_port} --extra_llm_api_options ${work_path}/gen_extra-llm-api-config.yaml --disagg_cluster_uri ${disagg_cluster_uri} --server-role generation" &
```

tensorrt_llm/commands/serve.py

Lines changed: 8 additions & 4 deletions
@@ -533,10 +533,14 @@ def serve_encoder(model: str, host: str, port: int, log_level: str,
     help=
     "The interval of logging metrics in seconds. Set to 0 to disable metrics logging."
 )
-def disaggregated(config_file: Optional[str],
-                  metadata_server_config_file: Optional[str],
-                  server_start_timeout: int, request_timeout: int,
-                  log_level: str, metrics_log_interval: int):
+def disaggregated(
+    config_file: Optional[str],
+    metadata_server_config_file: Optional[str],
+    server_start_timeout: int,
+    request_timeout: int,
+    log_level: str,
+    metrics_log_interval: int,
+):
     """Running server in disaggregated mode"""
 
     logger.set_level(log_level)

tensorrt_llm/serve/cluster_storage.py

Lines changed: 26 additions & 18 deletions
@@ -36,23 +36,26 @@ class WatchEvent:
 
 class WatchEventQueue:
 
-    def __init__(self, key_prefixes: List[str],
-                 events: asyncio.Queue[WatchEvent]):
+    def __init__(self, key_prefixes: List[str]):
         self.key_prefixes = key_prefixes
-        self.events = events
+        self.events = asyncio.Queue()
 
     async def drain(self):
         events = []
         event = await self.events.get()
-        logger.debug(f"Draining watch event: {self.events.qsize()}")
         events.append(event)
         while not self.events.empty():
             event = self.events.get_nowait()
             events.append(event)
             self.events.task_done()
-        logger.debug(f"after draining watch event: {self.events.qsize()}")
         return events
 
+    async def add_events(self, events: List[WatchEvent]):
+        loop = asyncio.get_event_loop()
+        for event in events:
+            self.events.put_nowait(event)
+        loop._write_to_self()
+
 
 class ClusterStorage(abc.ABC):
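The drain/add_events pattern in this hunk — await the first event, then empty the queue without further awaiting — can be illustrated standalone. This is a simplified sketch, not the TRT-LLM class itself; the `EventQueue` name and string events are made up.

```python
import asyncio

class EventQueue:
    def __init__(self):
        self.events = asyncio.Queue()

    async def drain(self):
        # block until at least one event arrives, then grab everything
        # that is already queued in a single non-blocking sweep
        batch = [await self.events.get()]
        while not self.events.empty():
            batch.append(self.events.get_nowait())
        return batch

    def add_events(self, events):
        # put_nowait never blocks, so producers can batch-enqueue safely
        for event in events:
            self.events.put_nowait(event)

async def main():
    q = EventQueue()
    q.add_events(["set:a", "set:b", "del:a"])
    return await q.drain()

print(asyncio.run(main()))  # -> ['set:a', 'set:b', 'del:a']
```

The commit additionally calls the private `loop._write_to_self()` after enqueueing, presumably to wake the event loop when events are added from outside it (e.g. from the etcd watch callback thread); the sketch above stays on a single loop and does not need that.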

@@ -104,17 +107,17 @@ async def get_prefix(self,
 
 
 def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
-    if cluster_uri.startswith("http"):
+    if cluster_uri.startswith("http://") or cluster_uri.startswith("https://"):
         return HttpClusterStorageServer(cluster_uri, cluster_name, **kwargs)
-    elif cluster_uri.startswith("etcd"):
+    elif cluster_uri.startswith("etcd://"):
         return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
     raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
 
 
 def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs):
-    if cluster_uri.startswith("http"):
+    if cluster_uri.startswith("http://") or cluster_uri.startswith("https://"):
         return HttpClusterStorageClient(cluster_uri, cluster_name, **kwargs)
-    elif cluster_uri.startswith("etcd"):
+    elif cluster_uri.startswith("etcd://"):
         return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
     raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
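The motivation for tightening the `startswith` checks shows up in a small sketch: a bare `"http"` prefix also matches unrelated or mistyped schemes. The helper below is hypothetical, mirroring only the dispatch logic of the factory functions above.

```python
def storage_kind(cluster_uri: str) -> str:
    # strict scheme match, as in the updated factory functions
    if cluster_uri.startswith("http://") or cluster_uri.startswith("https://"):
        return "http"
    elif cluster_uri.startswith("etcd://"):
        return "etcd"
    raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")

print(storage_kind("https://disagg-host:8000"))  # -> http
print(storage_kind("etcd://etcd-host:2379"))     # -> etcd

# The old check `startswith("http")` would have accepted this typo silently;
# the strict check rejects it with a clear error instead.
try:
    storage_kind("httpx://disagg-host:8000")
except ValueError as e:
    print(e)  # -> Invalid cluster storage URI: httpx://disagg-host:8000
```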

@@ -138,7 +141,11 @@ def key_time():
 
 class HttpClusterStorageServer(ClusterStorage):
 
-    def __init__(self, cluster_uri, cluster_name, server: FastAPI = None):
+    def __init__(self,
+                 cluster_uri,
+                 cluster_name,
+                 server: FastAPI = None,
+                 **kwargs):
         self._storage = {}
         self._lock = asyncio.Lock()
         self._watch_handles = {}
@@ -237,7 +244,7 @@ async def watch(self, key_prefix: str) -> WatchEventQueue:
             )
         else:
             self._watch_handles[key_prefix] = WatchEventQueue(
-                key_prefixes=[key_prefix], events=asyncio.Queue())
+                key_prefixes=[key_prefix])
         return self._watch_handles[key_prefix]
 
     async def unwatch(self, key_prefix: str) -> None:
@@ -291,7 +298,7 @@ async def _check_expired(self):
 
 class HttpClusterStorageClient(ClusterStorage):
 
-    def __init__(self, cluster_uri, cluster_name):
+    def __init__(self, cluster_uri, cluster_name, **kwargs):
         self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
             total=5))
         self._cluster_uri = cluster_uri if cluster_uri.startswith(
@@ -393,8 +400,8 @@ def __init__(self,
                  key_prefix: str,
                  cancel_event: Callable[[], None] = None):
         self.key_prefix = key_prefix
-        self._cancel_event = cancel_event
         self.events = asyncio.Queue()
+        self._cancel_event = cancel_event
 
     def cancel_event(self):
         if self._cancel_event:
@@ -406,7 +413,7 @@ def set_cancel_event(self, cancel_event: Callable[[], None]):
     def __del__(self):
         self.cancel_event()
 
-    def add_event(self, watch_resp):
+    def add_events_from_resp(self, watch_resp):
         try:
             for event in watch_resp.events:
                 # Event type is not in public interface of etcd3
@@ -430,7 +437,8 @@ class Etcd3ClusterStorage(ClusterStorage):
     def __init__(self,
                  cluster_uri: str,
                  cluster_name: str,
-                 one_single_lease: bool = False):
+                 one_single_lease: bool = False,
+                 **kwargs):
         cluster_uri = cluster_uri.replace("etcd://", "")
         host, port = cluster_uri.rsplit(":", 1)
         self._client = etcd3.client(host, port)
@@ -502,7 +510,7 @@ async def expire(self, key: str, ttl: int) -> bool:
         try:
             lease = self._get_lease(key, ttl)
             # TTL will be ignored since it can only be set when creating a lease
-            self.client.refresh_lease(lease_id=lease.id)
+            next(self.client.refresh_lease(lease_id=lease.id), None)
         except etcd3.Etcd3Exception as e:
             logger.error(f"Error refreshing lease {key}: {e}")
             return False
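The one-line fix above suggests that the etcd client's `refresh_lease` is implemented as a generator, so merely calling it performs no work; `next(..., None)` forces one iteration. A minimal illustration of that pitfall — the `refresh_lease` below is a hypothetical stand-in, not the real etcd3 API:

```python
sent = []  # records outbound keep-alive requests (illustrative only)

def refresh_lease(lease_id):
    # Stand-in for a client call implemented as a generator:
    # the body does not run until the generator is iterated.
    sent.append(lease_id)  # side effect: the keep-alive request goes out
    yield {"lease_id": lease_id, "ttl": 10}

refresh_lease(42)              # generator created but never iterated...
print(sent)                    # -> []  (no request was actually sent)

next(refresh_lease(42), None)  # ...iterating once triggers the side effect
print(sent)                    # -> [42]
```

Using `next(gen, None)` rather than `next(gen)` also avoids a `StopIteration` if the generator yields nothing.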
@@ -512,7 +520,7 @@ async def get_prefix(self,
                          key_prefix: str,
                          keys_only: bool = False) -> Dict[str, str]:
         try:
-            resp = self.client.get_prefix(key_prefix, keys_only=keys_only)
+            resp = self.client.get_prefix(key_prefix)
             return {
                 metadata.key.decode("utf-8"):
                 "" if keys_only else v.decode("utf-8")
@@ -528,7 +536,7 @@ async def watch(self, key_prefix: str) -> WatchEventQueue:
             return self._watch_handles[key_prefix]
         watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
         watch_id = self.client.add_watch_prefix_callback(
-            key_prefix, watch_handle.add_event)
+            key_prefix, watch_handle.add_events_from_resp)
         watch_handle.set_cancel_event(
             lambda: self.client.cancel_watch(watch_id))
         self._watch_handles[key_prefix] = watch_handle

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 5 additions & 2 deletions
@@ -94,18 +94,21 @@ def worker_key_prefix(self) -> str:
 
     async def watch_workers(self, get_existing_first: bool = True):
         workers = []
+        self._watch_handle = await self._cluster_storage.watch(
+            self.worker_key_prefix)
         if get_existing_first:
             # There is a tiny gap between getting existing workers and watching the key,
             # which may cause us to miss some workers registered in between.
             resp = await self._cluster_storage.get_prefix(
                 self.worker_key_prefix, keys_only=False)
+            events = []
             for worker_id, data in resp.items():
                 event = WatchEvent(storage_item=StorageItem(key=worker_id,
                                                             value=data),
                                    event_type=WatchEventType.SET)
                 workers.append(self._parse_worker_info(event))
-        self._watch_handle = await self._cluster_storage.watch(
-            self.worker_key_prefix)
+                events.append(event)
+            await self._watch_handle.add_events(events)
         return workers
 
     async def unwatch_workers(self) -> None:
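The reordering in this hunk — subscribe first, then list, then replay the listed items into the watch queue — is a common fix for the get-then-watch race: with the old order, a worker registering between the listing and the watch setup was lost entirely. A toy sketch of the idea using an in-memory store (all names hypothetical):

```python
import asyncio

class Store:
    def __init__(self):
        self.data = {}
        self.watch_queue = None  # set once someone subscribes

    def put(self, key, value):
        self.data[key] = value
        if self.watch_queue is not None:
            self.watch_queue.put_nowait((key, value))

async def watch_workers(store):
    # 1) subscribe first, so nothing registered from now on is missed
    store.watch_queue = asyncio.Queue()
    # 2) list the workers that already existed before the subscription
    existing = list(store.data.items())
    # 3) replay the existing entries into the same queue so the consumer
    #    sees one unified stream of (key, value) events
    for item in existing:
        store.watch_queue.put_nowait(item)
    return [k for k, _ in existing]

async def main():
    store = Store()
    store.put("worker/ctx-0", "alive")   # registered before watching
    workers = await watch_workers(store)
    store.put("worker/gen-0", "alive")   # registered after watching
    events = []
    while not store.watch_queue.empty():
        events.append(store.watch_queue.get_nowait())
    return workers, events

workers, events = asyncio.run(main())
print(workers)  # -> ['worker/ctx-0']
print(events)   # -> [('worker/ctx-0', 'alive'), ('worker/gen-0', 'alive')]
```

The trade-off flips: a worker registering between steps 1 and 2 can now show up both in the listing and as a watch event, so consumers must tolerate duplicate SET events instead of missing workers.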

tensorrt_llm/serve/router.py

Lines changed: 6 additions & 2 deletions
@@ -159,6 +159,7 @@ def __init__(self, server_role: ServerRole, servers: List[str],
     @abstractmethod
     def _on_servers_updated(self, old_servers, new_servers):
         """Called when the server list changes. Override in subclasses to handle index resets.
+        Called with lock already held.
         Args:
             old_servers: The previous server list
             new_servers: The new server list
@@ -639,8 +640,11 @@ async def finish_request(self,
                                   session=session)
 
     def _on_servers_updated(self, old_servers, new_servers):
-        raise NotImplementedError(
-            "KvCacheAwareRouter does not support server updates")
+        for new_server in new_servers:
+            self._server_state[new_server] = KvCacheAwareServerState(
+                new_server, self._use_tokens)
+        for old_server in old_servers:
+            self._server_state.pop(old_server, None)
 
 
 def create_router(router_config: Optional[RouterConfig],
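The new `_on_servers_updated` body is a simple state reconciliation: build fresh state for every server in `new_servers`, then drop state for every server in `old_servers`. A standalone sketch with a trivial stand-in for `KvCacheAwareServerState` (names and call shape hypothetical):

```python
class ServerState:
    # trivial stand-in for KvCacheAwareServerState
    def __init__(self, server, use_tokens):
        self.server = server
        self.use_tokens = use_tokens

def on_servers_updated(server_state, old_servers, new_servers, use_tokens=False):
    # mirror of the new callback body: create state for servers in
    # new_servers, then remove state for servers in old_servers
    for new_server in new_servers:
        server_state[new_server] = ServerState(new_server, use_tokens)
    for old_server in old_servers:
        server_state.pop(old_server, None)
    return server_state

state = {}
on_servers_updated(state, old_servers=[], new_servers=["ctx-0", "gen-0"])
print(sorted(state))  # -> ['ctx-0', 'gen-0']

on_servers_updated(state, old_servers=["ctx-0"], new_servers=["gen-1"])
print(sorted(state))  # -> ['gen-0', 'gen-1']
```

Note the ordering: because removals run after additions, a server present in both arguments would end up with no state, so in practice the two lists behave like disjoint added/removed sets even though the docstring describes them as the full previous and new lists.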
