Skip to content

Commit 394b204

Browse files
Initial push
Signed-off-by: Zhongxuan Wang <[email protected]>
1 parent f3f764e commit 394b204

File tree

3 files changed: 61 additions and 17 deletions.

components/src/dynamo/vllm/handlers.py

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,8 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
211211
),
212212
"prompt_tokens_details": (
213213
{"cached_tokens": request_output.num_cached_tokens}
214-
if request_output.num_cached_tokens
214+
if request_output.num_cached_tokens is not None
215+
and request_output.num_cached_tokens >= 0
215216
else None
216217
),
217218
}
@@ -241,10 +242,10 @@ async def generate_tokens(
241242
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
242243
if output.finish_reason:
243244
out["finish_reason"] = output.finish_reason
244-
out[
245-
"completion_usage"
246-
] = BaseWorkerHandler._build_completion_usage(
247-
request_output=res
245+
out["completion_usage"] = (
246+
BaseWorkerHandler._build_completion_usage(
247+
request_output=res
248+
)
248249
)
249250
if output.stop_reason:
250251
out["stop_reason"] = output.stop_reason
@@ -349,6 +350,9 @@ async def generate(self, request, context):
349350
request_id = context.id()
350351
logger.debug(f"Prefill Request ID: {request_id}")
351352

353+
# Extract overlap information from router (if present)
354+
overlap_blocks = request.get("estimated_prefix_hit_num_blocks", 0)
355+
352356
# Extract and decode multimodal data if present
353357
multi_modal_data = await self._extract_multimodal_data(request)
354358

@@ -391,13 +395,18 @@ async def generate(self, request, context):
391395

392396
token_ids = res.outputs[0].token_ids if res.outputs else []
393397

398+
# Build disaggregated_params with KV transfer params and router overlap
399+
disaggregated_params = {}
400+
if res.kv_transfer_params:
401+
disaggregated_params["kv_transfer_params"] = (
402+
res.kv_transfer_params
403+
)
404+
# Include router's overlap calculation for PrefillRouter
405+
disaggregated_params["overlap_blocks"] = overlap_blocks
406+
394407
output: Dict[str, Any] = {
395408
"token_ids": list(token_ids),
396-
"disaggregated_params": (
397-
{"kv_transfer_params": res.kv_transfer_params}
398-
if res.kv_transfer_params
399-
else None
400-
),
409+
"disaggregated_params": disaggregated_params,
401410
"completion_usage": BaseWorkerHandler._build_completion_usage(
402411
request_output=res
403412
),

lib/llm/src/entrypoint/input/common.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -267,8 +267,9 @@ where
267267
};
268268

269269
// Use the provided prefill chooser, or create a disabled one if not provided
270-
let prefill_chooser =
271-
prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode, enforce_disagg));
270+
let block_size = card.kv_cache_block_size;
271+
let prefill_chooser = prefill_chooser
272+
.unwrap_or_else(|| PrefillRouter::disabled(router_mode, enforce_disagg, block_size));
272273
let prefill_op = prefill_chooser.into_operator();
273274

274275
// Link with prefill chooser including backward edge for response flow

lib/llm/src/kv_router/prefill_router.rs

Lines changed: 39 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -57,16 +57,18 @@ pub struct PrefillRouter {
5757
cancel_token: CancellationToken,
5858
router_mode: RouterMode,
5959
enforce_disagg: bool,
60+
block_size: u32,
6061
}
6162

6263
impl PrefillRouter {
6364
/// Create a disabled prefill router that will never activate (passthrough only)
64-
pub fn disabled(router_mode: RouterMode, enforce_disagg: bool) -> Arc<Self> {
65+
pub fn disabled(router_mode: RouterMode, enforce_disagg: bool, block_size: u32) -> Arc<Self> {
6566
Arc::new(Self {
6667
prefill_router: OnceLock::new(),
6768
cancel_token: CancellationToken::new(),
6869
router_mode,
6970
enforce_disagg,
71+
block_size,
7072
})
7173
}
7274

@@ -86,6 +88,7 @@ impl PrefillRouter {
8688
cancel_token: cancel_token.clone(),
8789
router_mode,
8890
enforce_disagg,
91+
block_size: kv_cache_block_size,
8992
});
9093

9194
// Spawn background task to wait for activation
@@ -180,7 +183,8 @@ impl PrefillRouter {
180183
async fn call_prefill(
181184
&self,
182185
request: SingleIn<PreprocessedRequest>,
183-
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
186+
_block_size: u32,
187+
) -> Result<(PrefillResult, Option<u64>, u32), PrefillError> {
184188
// Get the prefill router, error if not activated
185189
let Some(prefill_router) = self.prefill_router.get() else {
186190
return Err(PrefillError::NotActivated);
@@ -247,12 +251,22 @@ impl PrefillRouter {
247251
.get("prefill_worker_id")
248252
.and_then(|v| v.as_u64())
249253
});
254+
255+
// Extract overlap_blocks from the response (set by prefill worker)
256+
let overlap_blocks = output
257+
.disaggregated_params
258+
.as_ref()
259+
.and_then(|params| params.get("overlap_blocks"))
260+
.and_then(|v| v.as_u64())
261+
.unwrap_or(0) as u32;
262+
250263
Ok((
251264
PrefillResult {
252265
disaggregated_params,
253266
prompt_tokens_details,
254267
},
255268
prefill_worker_id,
269+
overlap_blocks,
256270
))
257271
}
258272
}
@@ -297,7 +311,7 @@ impl
297311
let prefill_request = prefill_context;
298312

299313
// Attempt prefill
300-
let prefill_result = self.call_prefill(prefill_request).await;
314+
let prefill_result = self.call_prefill(prefill_request, self.block_size).await;
301315

302316
// Abort if cancelled during prefill
303317
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
@@ -310,8 +324,28 @@ impl
310324

311325
// Handle prefill result
312326
match prefill_result {
313-
Ok((prefill_result, prefill_worker_id)) => {
314-
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
327+
Ok((mut prefill_result, prefill_worker_id, overlap_blocks)) => {
328+
// Prefer vLLM's actual cached_tokens over router's estimate
329+
// vLLM queries the actual KV cache on the prefill worker (ground truth)
330+
// Router's overlap is just a prediction based on its global state
331+
let vllm_cached_tokens = prefill_result
332+
.prompt_tokens_details
333+
.as_ref()
334+
.and_then(|d| d.cached_tokens);
335+
let final_cached_tokens = if let Some(vllm_value) = vllm_cached_tokens {
336+
vllm_value
337+
} else {
338+
overlap_blocks * self.block_size
339+
};
340+
341+
prefill_result.prompt_tokens_details =
342+
Some(dynamo_async_openai::types::PromptTokensDetails {
343+
cached_tokens: Some(final_cached_tokens),
344+
audio_tokens: prefill_result
345+
.prompt_tokens_details
346+
.as_ref()
347+
.and_then(|d| d.audio_tokens),
348+
});
315349

316350
let mut decode_req = req;
317351
// Update request with prefill result

0 commit comments

Comments
 (0)