Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions components/src/dynamo/frontend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def parse_args():
default=True,
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing.",
)
parser.add_argument(
"--router-with-cascade-attention",
action="store_true",
dest="active_kv_reuse",
default=False,
help="KV Router: Enable KV reuse tracking for cascade attention backends. When enabled, tracks actual sequence hashes during decode for KV reuse optimization. By default, KV reuse tracking is disabled during decode.",
)
parser.add_argument(
"--busy-threshold",
type=float,
Expand Down Expand Up @@ -271,6 +278,7 @@ def signal_handler():
router_snapshot_threshold=flags.router_snapshot_threshold,
router_reset_states=flags.router_reset_states,
router_track_active_blocks=flags.router_track_active_blocks,
active_kv_reuse=flags.active_kv_reuse,
)
elif flags.router_mode == "random":
router_mode = RouterMode.Random
Expand Down
2 changes: 2 additions & 0 deletions docs/router/kv_cache_routing.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ The main KV-aware routing arguments:

- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.

- `--router-with-cascade-attention`: Enables KV reuse tracking during the decode phase. By default (when this flag is not provided), the router does not account for KV reuse during decoding load estimation, as normal decoding attention mechanisms do not make maximal use of KV reuse during decoding (aside from implicit L2 caching). However, cascade attention does benefit from KV reuse during decode. If your backend engine uses cascade attention, enable this flag to optimize routing decisions by tracking sequence hashes for decode-time KV reuse.

- `--busy-threshold`: Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`.

>[!Note]
Expand Down
1 change: 1 addition & 0 deletions launch/dynamo-run/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ impl Flags {
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
None,
None,
),
)
}
Expand Down
1 change: 1 addition & 0 deletions lib/bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ pub unsafe extern "C" fn dynamo_create_worker_selection_pipeline(
None,
None,
None,
Some(false),
))
} else {
None
Expand Down
5 changes: 4 additions & 1 deletion lib/bindings/python/rust/llm/entrypoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false, active_kv_reuse=false))]
#[allow(clippy::too_many_arguments)]
fn new(
overlap_score_weight: f64,
router_temperature: f64,
Expand All @@ -51,6 +52,7 @@ impl KvRouterConfig {
router_track_active_blocks: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
active_kv_reuse: bool,
) -> Self {
KvRouterConfig {
inner: RsKvRouterConfig {
Expand All @@ -61,6 +63,7 @@ impl KvRouterConfig {
router_track_active_blocks,
router_snapshot_threshold,
router_reset_states,
active_kv_reuse,
},
}
}
Expand Down
9 changes: 9 additions & 0 deletions lib/llm/src/kv_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ pub struct KvRouterConfig {

/// Whether to reset the router state on startup (default: false)
pub router_reset_states: bool,

/// Whether to track KV reuse during decode phase (default: false)
/// When false, random hashes are used to prevent deduplication
/// When true, actual sequence hashes are tracked for cascade attention
pub active_kv_reuse: bool,
}

impl Default for KvRouterConfig {
Expand All @@ -124,6 +129,7 @@ impl Default for KvRouterConfig {
router_track_active_blocks: true,
router_snapshot_threshold: Some(1000000),
router_reset_states: false,
active_kv_reuse: false,
}
}
}
Expand All @@ -140,6 +146,7 @@ impl KvRouterConfig {
track_active_blocks: Option<bool>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
active_kv_reuse: Option<bool>,
) -> Self {
let default = Self::default();
Self {
Expand All @@ -152,6 +159,7 @@ impl KvRouterConfig {
router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold),
router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
active_kv_reuse: active_kv_reuse.unwrap_or(default.active_kv_reuse),
}
}
}
Expand Down Expand Up @@ -279,6 +287,7 @@ impl KvRouter {
selector,
kv_router_config.router_replica_sync,
consumer_uuid.clone(),
kv_router_config.active_kv_reuse,
)
.await?;

Expand Down
3 changes: 3 additions & 0 deletions lib/llm/src/kv_router/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ pub struct KvScheduler {
}

impl KvScheduler {
#[allow(clippy::too_many_arguments)]
pub async fn start(
component: Component,
block_size: u32,
Expand All @@ -101,6 +102,7 @@ impl KvScheduler {
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
active_kv_reuse: bool,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instances: Vec<Instance> = instances_rx.borrow().clone();
Expand All @@ -127,6 +129,7 @@ impl KvScheduler {
workers_with_configs.read().await.clone(), // this includes dp_size info
replica_sync,
router_uuid,
active_kv_reuse,
));

// Spawn background task to monitor and update workers_with_configs
Expand Down
36 changes: 31 additions & 5 deletions lib/llm/src/kv_router/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,15 @@ pub struct ActiveSequences {

/// Set of request IDs to check for expiry
expiry_requests: HashSet<RequestId>,

/// Whether to track KV reuse during decode (false = use random hashes)
#[getter(copy)]
active_kv_reuse: bool,
}

impl ActiveSequences {
/// Create a new SharedSequenceManager instance
pub fn new(block_size: usize) -> Self {
pub fn new(block_size: usize, active_kv_reuse: bool) -> Self {
// TODO: make this not a hard req
assert!(block_size > 1, "block_size must be greater than 1");

Expand All @@ -85,6 +89,7 @@ impl ActiveSequences {
active_tokens: 0,
expiry_timer: Instant::now() + EXPIRY_DURATION,
expiry_requests: HashSet::new(),
active_kv_reuse,
}
}

Expand Down Expand Up @@ -135,7 +140,18 @@ impl ActiveSequences {
self.active_tokens += prefill_tokens;

if let Some(sequence) = token_sequence {
let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence
// When active_kv_reuse is false, replace each block hash with a random hash
let sequence_to_use: Vec<SequenceHash> = if self.active_kv_reuse {
sequence
} else {
// Generate random hashes for each block to prevent KV reuse tracking
sequence
.iter()
.map(|_| (Uuid::new_v4().as_u128() & 0xFFFFFFFFFFFFFFFF) as u64)
.collect()
};

let sequence_with_refs: Vec<(SequenceHash, Rc<()>)> = sequence_to_use
.iter()
.map(|block| (*block, self.touch_block(block)))
.collect();
Expand Down Expand Up @@ -287,6 +303,7 @@ pub struct ActiveSequencesMultiWorker {
component: Component,
router_id: Uuid,
replica_sync: bool,
active_kv_reuse: bool,
}

impl ActiveSequencesMultiWorker {
Expand All @@ -296,6 +313,7 @@ impl ActiveSequencesMultiWorker {
workers_with_configs: HashMap<u64, Option<ModelRuntimeConfig>>,
replica_sync: bool,
router_uuid: String,
active_kv_reuse: bool,
) -> Self {
assert!(block_size > 1, "block_size must be greater than 1");

Expand All @@ -319,7 +337,8 @@ impl ActiveSequencesMultiWorker {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
// Create a child cancellation token from the component's runtime
let cancel_token = component.drt().runtime().child_token();
let (sender, handle) = Self::start_worker(block_size, cancel_token);
let (sender, handle) =
Self::start_worker(block_size, active_kv_reuse, cancel_token);
senders.insert(worker, sender);
handles.insert(worker, handle);
}
Expand All @@ -333,6 +352,7 @@ impl ActiveSequencesMultiWorker {
component: component.clone(),
router_id,
replica_sync,
active_kv_reuse,
};

// Start the subscription loop only if replica_sync is enabled
Expand Down Expand Up @@ -365,6 +385,7 @@ impl ActiveSequencesMultiWorker {
/// Helper method to start a worker task
fn start_worker(
block_size: usize,
active_kv_reuse: bool,
cancel_token: CancellationToken,
) -> (
tokio::sync::mpsc::UnboundedSender<UpdateSequences>,
Expand All @@ -380,7 +401,7 @@ impl ActiveSequencesMultiWorker {
.unwrap();

runtime.block_on(async move {
let mut active_sequences = ActiveSequences::new(block_size);
let mut active_sequences = ActiveSequences::new(block_size, active_kv_reuse);
let mut request_rx = request_rx;

loop {
Expand Down Expand Up @@ -598,6 +619,7 @@ impl ActiveSequencesMultiWorker {

let (sender, handle) = Self::start_worker(
self.block_size,
self.active_kv_reuse,
self.component.drt().runtime().child_token(),
);
self.senders.insert(*worker, sender);
Expand Down Expand Up @@ -902,7 +924,7 @@ mod tests {
#[test]
fn test_active_sequences_shared_blocks() {
let block_size = 4;
let mut seq_manager = ActiveSequences::new(block_size);
let mut seq_manager = ActiveSequences::new(block_size, true);

seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0);
assert_eq!(seq_manager.active_blocks(), 3);
Expand Down Expand Up @@ -968,13 +990,15 @@ mod tests {
workers_with_configs.clone(),
true,
Uuid::new_v4().to_string(),
true,
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
workers_with_configs,
true,
Uuid::new_v4().to_string(),
true,
));

// Give some time for the subscription loops to start
Expand Down Expand Up @@ -1125,13 +1149,15 @@ mod tests {
workers_with_configs.clone(),
true,
Uuid::new_v4().to_string(),
true,
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
workers_with_configs,
true,
Uuid::new_v4().to_string(),
true,
));

// Give some time for the subscription loops to start
Expand Down
Loading