
Commit 8845ae8

[Python] Add Microserving code example (#3089)
This commit adds code examples for Microserving.
1 parent cf7ae82 commit 8845ae8

File tree

8 files changed: +406 −121 lines changed


cpp/serve/config.cc

+6 −5

@@ -186,12 +186,13 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
   } else {
     return TResult::Error("Unknown grammar execution mode " + grammar_execution_mode);
   }
-  Result<DisaggConfig> disagg_config =
-      DisaggConfig::FromJSON(json::Lookup<picojson::object>(config, "disagg_config"));
-  if (disagg_config.IsErr()) {
-    return TResult::Error(disagg_config.UnwrapErr());
+  if (auto disagg_config_obj = json::LookupOptional<picojson::object>(config, "disagg_config")) {
+    Result<DisaggConfig> disagg_config = DisaggConfig::FromJSON(disagg_config_obj.value());
+    if (disagg_config.IsErr()) {
+      return TResult::Error(disagg_config.UnwrapErr());
+    }
+    res.disagg_config = disagg_config.Unwrap();
   }
-  res.disagg_config = disagg_config.Unwrap();
   return TResult::Ok(res);
 }

docs/index.rst

+8 −0

@@ -65,6 +65,13 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a
    install/gpu.rst
    install/emcc.rst
 
+.. toctree::
+   :maxdepth: 1
+   :caption: Microserving API
+   :hidden:
+
+   microserving/tutorial.rst
+
 .. toctree::
    :maxdepth: 1
    :caption: Community
@@ -80,3 +87,4 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a
    :hidden:
 
    privacy.rst
+

docs/microserving/tutorial.rst

+205 (new file)
Implement LLM Cross-engine Orchestration Patterns
==================================================

In this tutorial, we introduce how to implement LLM cross-engine
orchestration patterns, such as prefill-decode disaggregation, in MLC-LLM
via the MicroServing API. To make disaggregated serving programmable,
MicroServing provides a new RISC-style approach to designing LLM serving
APIs at the sub-request level. It enables programmable cross-engine serving
patterns in a few lines of Python code. For more information on the
MicroServing API, check out
https://blog.mlc.ai/2025/01/07/microserving-llm-engines.

Below is an example implementation of prefill-decode disaggregation. An
LLM cross-engine orchestration pattern is implemented in a router, which
dispatches an original OpenAI-style completion request to a chain of
MicroServing API calls. In this code example, we create a subclass of
``Router`` (which includes wrappers for calling the MicroServing APIs) and
override its ``translate_request`` method. ``translate_request`` takes in
a request and a unique identifier for that request (``request_id``), and
returns an ``AsyncGenerator`` of responses. We launch the ``CustomRouter``
and two engines, each with tensor parallel degree 2. Engine 0 is the
prefill engine and engine 1 is the decode engine.

.. code:: python

   from mlc_llm.router import Router
   from mlc_llm.protocol import openai_api_protocol
   from typing import Any, AsyncGenerator
   from mlc_llm.serve.entrypoints import microserving_entrypoints
   from mlc_llm.interface.router import serve

   import aiohttp

   class CustomRouter(Router):
       async def translate_request(
           self, request: openai_api_protocol.CompletionRequest, request_id: str
       ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:
           pass


   serve(
       model="/path/to/model",  # replace this with actual path
       model_lib="/path/to/model_lib",  # replace this with actual path
       router_host="127.0.0.1",
       router_port=9123,
       endpoint_hosts=["127.0.0.1", "127.0.0.1"],
       endpoint_ports=[9124, 9125],
       endpoint_num_gpus=[2, 2],
       enable_prefix_cache=False,
       router_type=CustomRouter,
   )

In the ``translate_request`` method, we first assign ``request_id`` to
``request.user``; the request id will later be passed as an argument to
the MicroServing API calls.

.. code:: python

   # we will pass request_id as an argument in microserving API calls
   request.user = request_id

Next, call ``prep_recv`` on the decode engine to prepare KV entries for
receiving from the remote engine. Passing ``end=decode_start`` (where
``decode_start = len(request.prompt) - 1``) lets the prefill engine
prefill everything except the last token, which ensures that the prefill
engine does not need sampling logic. ``prep_recv`` returns the address for
receiving KV from the remote engine and the matched prefix length. For
simplicity, we do not enable prefix cache in this tutorial, so we only
need the KV address here.

.. code:: python

   async with aiohttp.ClientSession(
       timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True
   ) as session:
       decode_start = len(request.prompt) - 1
       # 1. Ask decode engine to prepare KV entries to receive from prefill engine
       prep_recv_request = microserving_entrypoints.PrepRecvRequest(
           **request.model_dump(), end=decode_start
       )
       (
           kv_addr_info,
           _,
       ) = await self.send_prepare_receive(
           session=session,
           request=prep_recv_request,
           server_url=self.server_urls[1],  # engine 0 is prefill, engine 1 is decode; this is the decode engine
       )

Then, call ``remote_send`` on the prefill engine to compute and send KV to
the decode engine. ``recv_rank=self.device_id_starts[1]`` means that we
are sending the KV to engine 1 (the decode engine).

.. code:: python

   # 2. Ask prefill engine to send KV to decode engine
   remote_send_request = microserving_entrypoints.RemoteSendRequest(
       **request.model_dump(),
       begin=0,
       end=decode_start,
       kv_addr_info=kv_addr_info,
       recv_rank=self.device_id_starts[1],  # the rank of decode engine
   )
   await self.send_remote_send(
       session=session,
       request=remote_send_request,
       server_url=self.server_urls[0],  # prefill engine
   )

Finally, call ``start_generate`` on the decode engine to start generating
tokens. ``begin=decode_start`` means we prefill the last token in the
prompt and then start decoding. Note that the decode process of the
request may be preempted; in that case we yield ``None``, so that the
router reruns the ``translate_request`` method.

.. code:: python

   # 3. Start decoding
   start_generate_request = microserving_entrypoints.StartGenerateRequest(
       **request.model_dump(),
       begin=decode_start,
   )
   async for response in self.send_start_generate(
       session=session,
       request=start_generate_request,
       server_url=self.server_urls[1],
   ):
       if len(response.choices) > 0:
           finish_reason = response.choices[0].finish_reason
           if finish_reason == "preempt":
               yield None
       yield response

Bringing everything together, the complete code is shown below:

.. code:: python

   from mlc_llm.router import Router
   from mlc_llm.protocol import openai_api_protocol
   from typing import Any, AsyncGenerator
   from mlc_llm.serve.entrypoints import microserving_entrypoints
   from mlc_llm.interface.router import serve

   import aiohttp

   class CustomRouter(Router):
       async def translate_request(
           self, request: openai_api_protocol.CompletionRequest, request_id: str
       ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:
           # we will pass request_id as an argument in microserving API calls
           request.user = request_id

           async with aiohttp.ClientSession(
               timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True
           ) as session:
               decode_start = len(request.prompt) - 1
               # 1. Ask decode engine to prepare KV entries to receive from prefill engine
               prep_recv_request = microserving_entrypoints.PrepRecvRequest(
                   **request.model_dump(), end=decode_start
               )
               (
                   kv_addr_info,
                   _,
               ) = await self.send_prepare_receive(
                   session=session,
                   request=prep_recv_request,
                   server_url=self.server_urls[1],  # engine 0 is prefill, engine 1 is decode; this is the decode engine
               )
               # 2. Ask prefill engine to send KV to decode engine
               remote_send_request = microserving_entrypoints.RemoteSendRequest(
                   **request.model_dump(),
                   begin=0,
                   end=decode_start,
                   kv_addr_info=kv_addr_info,
                   recv_rank=self.device_id_starts[1],  # the rank of decode engine
               )
               await self.send_remote_send(
                   session=session,
                   request=remote_send_request,
                   server_url=self.server_urls[0],  # prefill engine
               )
               # 3. Start decoding
               start_generate_request = microserving_entrypoints.StartGenerateRequest(
                   **request.model_dump(),
                   begin=decode_start,
               )
               async for response in self.send_start_generate(
                   session=session,
                   request=start_generate_request,
                   server_url=self.server_urls[1],
               ):
                   if len(response.choices) > 0:
                       finish_reason = response.choices[0].finish_reason
                       if finish_reason == "preempt":
                           yield None
                   yield response


   serve(
       model="/path/to/model",  # replace this with actual path
       model_lib="/path/to/model_lib",  # replace this with actual path
       router_host="127.0.0.1",
       router_port=9123,
       endpoint_hosts=["127.0.0.1", "127.0.0.1"],
       endpoint_ports=[9124, 9125],
       endpoint_num_gpus=[2, 2],
       enable_prefix_cache=False,
       router_type=CustomRouter,
   )

+69 (new file)

from mlc_llm.router import Router
from mlc_llm.protocol import openai_api_protocol
from typing import Any, AsyncGenerator
from mlc_llm.serve.entrypoints import microserving_entrypoints
from mlc_llm.interface.router import serve

import aiohttp


class CustomRouter(Router):
    async def translate_request(
        self, request: openai_api_protocol.CompletionRequest, request_id: str
    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:
        # we will pass request_id as an argument in microserving API calls
        request.user = request_id

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True
        ) as session:
            decode_start = len(request.prompt) - 1
            # 1. Ask decode engine to prepare KV entries to receive from prefill engine
            prep_recv_request = microserving_entrypoints.PrepRecvRequest(
                **request.model_dump(), end=decode_start
            )
            (
                kv_addr_info,
                _,
            ) = await self.send_prepare_receive(
                session=session,
                request=prep_recv_request,
                server_url=self.server_urls[1],  # engine 0 is prefill, engine 1 is decode; this is the decode engine
            )
            # 2. Ask prefill engine to send KV to decode engine
            remote_send_request = microserving_entrypoints.RemoteSendRequest(
                **request.model_dump(),
                begin=0,
                end=decode_start,
                kv_addr_info=kv_addr_info,
                recv_rank=self.device_id_starts[1],  # the rank of decode engine
            )
            await self.send_remote_send(
                session=session,
                request=remote_send_request,
                server_url=self.server_urls[0],  # prefill engine
            )
            # 3. Start decoding
            start_generate_request = microserving_entrypoints.StartGenerateRequest(
                **request.model_dump(),
                begin=decode_start,
            )
            async for response in self.send_start_generate(
                session=session,
                request=start_generate_request,
                server_url=self.server_urls[1],
            ):
                if len(response.choices) > 0:
                    finish_reason = response.choices[0].finish_reason
                    if finish_reason == "preempt":
                        yield None
                yield response


serve(
    model="/opt/dlami/nvme/models/Llama-3.1-8B-Instruct-q0f16-MLC",  # replace this with actual path
    model_lib="/opt/dlami/nvme/models/Llama-3.1-8B-Instruct-q0f16-MLC/lib_disagg.so",  # replace this with actual path
    router_host="127.0.0.1",
    router_port=9123,
    endpoint_hosts=["127.0.0.1", "127.0.0.1"],
    endpoint_ports=[9124, 9125],
    endpoint_num_gpus=[2, 2],
    enable_prefix_cache=False,
    router_type=CustomRouter,
)

python/mlc_llm/interface/router.py

+5 −4

@@ -2,7 +2,7 @@
 
 # pylint: disable=fixme
 from http import HTTPStatus
-from typing import AsyncGenerator, List, Literal, Optional
+from typing import AsyncGenerator, List, Literal, Optional, Type
 
 import fastapi
 import uvicorn
@@ -23,12 +23,13 @@ def serve(
     endpoint_ports: List[int],
     endpoint_num_gpus: List[int],
     enable_prefix_cache: bool,
-    router_mode: Literal["disagg", "round-robin"],
-    pd_balance_factor: float,
+    router_mode: Literal["disagg", "round-robin"] = "round-robin",
+    pd_balance_factor: float = 0.0,
+    router_type: Type[Router] = Router,
 ):  # pylint: disable=too-many-arguments
     """Start the router with the specified configuration."""
     # 1. Instantiate router
-    router = Router(
+    router = router_type(
         model=model,
         model_lib=model_lib,
         hosts=endpoint_hosts,
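
With this change, ``router_mode`` and ``pd_balance_factor`` gain defaults, and the new
``router_type`` parameter makes ``serve`` instantiate whichever ``Router`` subclass the
caller passes in; this is what the tutorial's ``CustomRouter`` relies on. A minimal sketch
of the parameter in use (the ``NoopRouter`` name and the single-engine setup here are
hypothetical; the tutorial above shows the real two-engine disaggregation case):

    from mlc_llm.router import Router
    from mlc_llm.interface.router import serve

    class NoopRouter(Router):
        """Hypothetical subclass that keeps the base Router behavior unchanged."""

    serve(
        model="/path/to/model",          # replace with an actual path
        model_lib="/path/to/model_lib",  # replace with an actual path
        router_host="127.0.0.1",
        router_port=9123,
        endpoint_hosts=["127.0.0.1"],
        endpoint_ports=[9124],
        endpoint_num_gpus=[2],
        enable_prefix_cache=False,
        router_type=NoopRouter,          # new parameter; defaults to Router
    )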

python/mlc_llm/protocol/microserving_protocol.py

+7 −11

@@ -16,17 +16,14 @@ class PrepRecvRequest(CompletionRequest):
         The entries of this KV range will be allocated on the decode instance.
     """
 
-    kv_window_end: int
+    end: int
 
 
 class PrepRecvResponse(BaseModel):
     """The response body for prep_recv request in MicroServing.
 
     Attributes
     ----------
-    prompt_length : int
-        The length of the request prompt in tokens.
-
     prefix_matched_length : int
         The matched common prefix length on the decode instance when
         prefix cache is enabled, or 0 if there is no prefix cache.
@@ -35,9 +32,8 @@ class PrepRecvResponse(BaseModel):
         The metadata of the KV range on the destination decode instance.
     """
 
-    prompt_length: int
-    prefix_matched_length: int
     kv_append_metadata: str
+    prefix_matched_length: int
 
 
 class RemoteSendRequest(CompletionRequest):
@@ -58,10 +54,10 @@ class RemoteSendRequest(CompletionRequest):
         The node group offset of the destination decode instance.
     """
 
-    kv_window_begin: int
-    kv_window_end: int
-    kv_append_metadata: str
-    dst_group_offset: int
+    begin: int
+    end: int
+    kv_addr_info: str
+    recv_rank: int
 
 
 class StartGenerateRequest(CompletionRequest):
@@ -73,4 +69,4 @@ class StartGenerateRequest(CompletionRequest):
     Denote the start of the KV range to prefill on the decode instance.
     """
 
-    kv_window_begin: int
+    begin: int
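
The renamed fields line up with the keyword arguments used in the tutorial's router code.
A minimal sketch of constructing the three request bodies with the new field names; the
model name, prompt, ``kv_addr_info`` string, and ``recv_rank`` value below are placeholders,
and in the tutorial they come from the running engines:

    from mlc_llm.protocol import openai_api_protocol
    from mlc_llm.serve.entrypoints import microserving_entrypoints

    request = openai_api_protocol.CompletionRequest(model="dummy-model", prompt="Hello world")
    decode_start = len(request.prompt) - 1  # prefill everything except the last token

    # prep_recv: decode engine allocates KV entries for the range [0, end)
    prep_recv_request = microserving_entrypoints.PrepRecvRequest(
        **request.model_dump(), end=decode_start
    )

    # remote_send: prefill engine computes KV for [begin, end) and sends it to recv_rank
    remote_send_request = microserving_entrypoints.RemoteSendRequest(
        **request.model_dump(),
        begin=0,
        end=decode_start,
        kv_addr_info="...",  # kv_append_metadata returned by prep_recv
        recv_rank=2,         # hypothetical device-id start of the decode engine
    )

    # start_generate: decode engine prefills from begin and starts decoding
    start_generate_request = microserving_entrypoints.StartGenerateRequest(
        **request.model_dump(), begin=decode_start
    )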
