5 changes: 5 additions & 0 deletions bindings/python/src/lib.rs
@@ -380,6 +380,7 @@ struct Router {
assignment_mode: String,
max_payload_size: usize,
dp_aware: bool,
dp_minimum_tokens_scheduler: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
@@ -739,6 +740,7 @@ impl Router {
self.server_cert_path.as_ref(),
self.server_key_path.as_ref(),
)
.dp_minimum_tokens_scheduler(self.dp_minimum_tokens_scheduler)
.build()
}
}
@@ -764,6 +766,7 @@ impl Router {
assignment_mode = String::from("random"),
max_payload_size = 512 * 1024 * 1024,
dp_aware = false,
dp_minimum_tokens_scheduler = false,
api_key = None,
log_dir = None,
log_level = None,
@@ -870,6 +873,7 @@ impl Router {
assignment_mode: String,
max_payload_size: usize,
dp_aware: bool,
dp_minimum_tokens_scheduler: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
@@ -985,6 +989,7 @@ impl Router {
assignment_mode,
max_payload_size,
dp_aware,
dp_minimum_tokens_scheduler,
api_key,
log_dir,
log_level,
1 change: 1 addition & 0 deletions bindings/python/src/smg/router.py
@@ -166,6 +166,7 @@ class Router:
max_tree_size: Maximum size of the approximation tree for cache-aware routing.
Default: 2^24
dp_aware: Enable data parallelism aware scheduling. Default: False
dp_minimum_tokens_scheduler: Enable the minimum tokens scheduler for the data parallel group. Default: False
enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When
enabled, the router can manage multiple models simultaneously with per-model
load balancing policies. Default: False
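
A minimal usage sketch for the new flag (the dp_aware and dp_minimum_tokens_scheduler keyword names match the binding signature in this PR; worker_urls and the import path are assumptions for illustration):

from smg.router import Router  # import path assumed from bindings/python/src/smg/router.py

# Hypothetical construction: worker_urls is illustrative; dp_aware and
# dp_minimum_tokens_scheduler are keyword arguments shown in this PR.
router = Router(
    worker_urls=["http://worker1:8080", "http://worker2:8080"],
    dp_aware=True,
    dp_minimum_tokens_scheduler=True,
)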
6 changes: 6 additions & 0 deletions bindings/python/src/smg/router_args.py
@@ -42,6 +42,7 @@ class RouterArgs:
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
bucket_adjust_interval_secs: int = 5
dp_aware: bool = False
dp_minimum_tokens_scheduler: bool = False
enable_igw: bool = False  # Enable IGW (Inference Gateway) mode for multi-model support
api_key: str | None = None
log_dir: str | None = None
@@ -382,6 +383,11 @@ def add_cli_args(
action="store_true",
help="Enable data parallelism aware schedule",
)
routing_group.add_argument(
f"--{prefix}dp-minimum-tokens-scheduler",
action="store_true",
help="Enable minimum tokens scheduler for data parallel group",
)
Comment on lines +386 to +390

Collaborator: please update documentation as there is a user-facing change here

Contributor Author: fixed
routing_group.add_argument(
f"--{prefix}enable-igw",
action="store_true",
8 changes: 8 additions & 0 deletions crates/protocols/src/worker.rs
@@ -938,6 +938,14 @@ impl WorkerLoadResponse {
pub fn total_used_tokens(&self) -> i64 {
self.loads.iter().map(|l| l.num_used_tokens as i64).sum()
}

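/// Per-DP-rank load view: maps each snapshot's dp_rank to its num_used_tokens.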
pub fn dp_rank_loads(&self) -> HashMap<isize, isize> {
let mut map = HashMap::new();
for snapshot in &self.loads {
map.insert(snapshot.dp_rank as isize, snapshot.num_used_tokens as isize);
}
map
}
}

/// Individual worker load information
1 change: 1 addition & 0 deletions docs/reference/configuration.md
@@ -126,6 +126,7 @@ Controls how requests are distributed across workers.
|--------|-------------|---------|
| `--dp-aware` | Enable data parallelism aware scheduling | `false` |
| `--enable-igw` | Enable IGW (Inference Gateway) mode for multi-model support | `false` |
| `--dp-minimum-tokens-scheduler` | Enable the minimum tokens scheduler for the data parallel group | `false` |
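
The same flag is exposed through the Python bindings' argument parser; a sketch (assumed: RouterArgs.add_cli_args accepts a bare argparse parser with an empty prefix — its full signature is not shown in this PR):

import argparse

from smg.router_args import RouterArgs  # import path as in this PR

parser = argparse.ArgumentParser()
RouterArgs.add_cli_args(parser)  # assumed: parser-only call, default (empty) prefix
args = parser.parse_args(["--dp-aware", "--dp-minimum-tokens-scheduler"])
print(args.dp_minimum_tokens_scheduler)  # True when the flag is passed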

---

5 changes: 5 additions & 0 deletions model_gateway/src/config/builder.rs
@@ -494,6 +494,11 @@ impl RouterConfigBuilder {
self
}

pub fn dp_minimum_tokens_scheduler(mut self, enable: bool) -> Self {
self.config.dp_minimum_tokens_scheduler = enable;
self
}

// ==================== Option Setters ====================
// Accept Option<T> and only set if Some

3 changes: 3 additions & 0 deletions model_gateway/src/config/types.rs
@@ -26,6 +26,8 @@ pub struct RouterConfig {
#[serde(default = "default_load_monitor_interval_secs")]
pub load_monitor_interval_secs: u64,
pub dp_aware: bool,
#[serde(default)]
pub dp_minimum_tokens_scheduler: bool,
pub api_key: Option<String>,
pub discovery: Option<DiscoveryConfig>,
pub metrics: Option<MetricsConfig>,
@@ -536,6 +538,7 @@ impl Default for RouterConfig {
worker_startup_check_interval_secs: 30,
load_monitor_interval_secs: 10,
dp_aware: false,
dp_minimum_tokens_scheduler: false,
api_key: None,
discovery: None,
metrics: None,
2 changes: 2 additions & 0 deletions model_gateway/src/core/mod.rs
@@ -25,6 +25,7 @@ pub mod token_bucket;
pub mod worker;
pub mod worker_builder;
pub mod worker_event;
pub mod worker_load;
pub mod worker_manager;
pub mod worker_registry;
pub mod worker_service;
@@ -47,6 +48,7 @@ pub use worker::{
DEFAULT_BOOTSTRAP_PORT, MOONCAKE_CONNECTOR,
};
pub use worker_builder::BasicWorkerBuilder;
pub use worker_load::WorkerLoadManager;
pub use worker_manager::{LoadMonitor, WorkerManager};
pub use worker_registry::{HashRing, WorkerRegistry};
pub use worker_service::WorkerService;
133 changes: 133 additions & 0 deletions model_gateway/src/core/worker_load.rs
@@ -0,0 +1,133 @@
//! Worker load
//!
//! Records and manages the per-DP-rank (data-parallel group) load of workers.
use std::{collections::HashMap, fmt::Debug, sync::RwLock};

use tracing::debug;

use crate::core::Worker;

#[derive(Debug, Default)]
pub struct WorkerLoadManager {
// <worker url, <dp_rank, used tokens>>
dp_cached_loads: RwLock<HashMap<String, HashMap<isize, isize>>>,
}

impl WorkerLoadManager {
pub fn new() -> Self {
Self {
dp_cached_loads: RwLock::new(HashMap::new()),
}
}

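/// Merge freshly fetched per-worker DP loads into the cache; each worker's entry is replaced wholesale.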
pub fn update_dp_loads(&self, loads: &HashMap<String, HashMap<isize, isize>>) {
debug!("WorkerLoadManager update_dp_loads map:{:?}", loads);
let mut cached = self
.dp_cached_loads
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
cached.extend(loads.iter().map(|(k, v)| (k.clone(), v.clone())));
}

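/// Pick the DP rank of `worker` with the lowest cached load (ties broken by the
/// lower rank), optimistically bump it by `increment` so concurrent selections
/// spread out before the next load refresh, and return the chosen rank.
/// Returns `None` if no loads are cached for this worker.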
pub fn select_and_increment_lowest_dp_load(
&self,
worker: &dyn Worker,
increment: isize,
) -> Option<isize> {
let mut cached = self
.dp_cached_loads
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
let loads = cached.get_mut(worker.url())?;
let (&dp_rank, _) = loads.iter().min_by_key(|&(rank, load)| (*load, *rank))?;
if let Some(v) = loads.get_mut(&dp_rank) {
*v += increment;
}
Some(dp_rank)
}

pub fn remove_workers(&self, urls: &[String]) {
let mut cached = self
.dp_cached_loads
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner());
for url in urls {
cached.remove(url);
}
}
}

#[cfg(test)]
mod dp_load_manager_tests {
use super::*;
use crate::core::{BasicWorkerBuilder, WorkerType};

#[test]
fn test_new_dp_load_manager_instance() {
let dp_load_manager = WorkerLoadManager::new();
let cached = dp_load_manager.dp_cached_loads.read().unwrap();
assert!(cached.is_empty());
}

#[test]
fn test_update_dp_load() {
let manager = WorkerLoadManager::new();
let mut loads = HashMap::new();

// insert worker1_load
let mut worker1_load = HashMap::new();
worker1_load.insert(0, 2);
worker1_load.insert(1, 1);
loads.insert("http://worker1:8080".to_string(), worker1_load);

// insert worker2_load
let mut worker2_load = HashMap::new();
worker2_load.insert(0, 3);
loads.insert("http://worker2:8080".to_string(), worker2_load);

// update
manager.update_dp_loads(&loads);

// assert
let cached = manager.dp_cached_loads.read().unwrap();
assert_eq!(cached.len(), 2);

let worker2_cache = cached.get("http://worker2:8080").unwrap();
assert_eq!(worker2_cache.get(&0), Some(&3));
}

#[test]
fn test_select_and_increment_lowest_dp_load_multiple() {
let worker = BasicWorkerBuilder::new("http://worker:8080")
.worker_type(WorkerType::Regular)
.api_key("test_key")
.build();

let manager = WorkerLoadManager::new();
let mut loads = HashMap::new();
let mut worker_load = HashMap::new();
worker_load.insert(0, 10);
worker_load.insert(1, 3);
worker_load.insert(2, 7);
loads.insert(worker.url().to_string(), worker_load);
manager.update_dp_loads(&loads);

let selected = manager.select_and_increment_lowest_dp_load(&worker, 4);

assert_eq!(selected, Some(1));
let cached = manager.dp_cached_loads.read().unwrap();
assert_eq!(*cached.get(worker.url()).unwrap().get(&1).unwrap(), 3 + 4);
}

#[test]
fn test_select_and_increment_lowest_dp_load_none_worker() {
let worker = BasicWorkerBuilder::new("http://nonexist:8080")
.worker_type(WorkerType::Regular)
.api_key("test")
.build();

let manager = WorkerLoadManager::new();
let result = manager.select_and_increment_lowest_dp_load(&worker, 1);
assert_eq!(result, None);
}
}
17 changes: 14 additions & 3 deletions model_gateway/src/core/worker_manager.rs
@@ -22,7 +22,7 @@ use tracing::{debug, info};
use crate::{
core::{
metrics_aggregator::{self, MetricPack},
ConnectionMode, Worker, WorkerLoadManager, WorkerRegistry, WorkerType,
},
policies::PolicyRegistry,
};
@@ -307,6 +307,7 @@ impl WorkerManager {
pub struct LoadMonitor {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
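/// Per-worker, per-DP-rank load cache refreshed by the monitoring tasks.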
pub worker_load_manager: Arc<WorkerLoadManager>,
client: reqwest::Client,
default_interval: Duration,
tx: watch::Sender<HashMap<String, WorkerLoadResponse>>,
@@ -327,6 +328,7 @@ impl LoadMonitor {
Self {
worker_registry,
policy_registry,
worker_load_manager: Arc::new(WorkerLoadManager::new()),
client,
default_interval: Duration::from_secs(default_interval_secs),
tx,
@@ -356,6 +358,7 @@

let worker_registry = Arc::clone(&self.worker_registry);
let policy_registry = Arc::clone(&self.policy_registry);
let worker_load_manager = Arc::clone(&self.worker_load_manager);
let client = self.client.clone();
let tx = self.tx.clone();
let group_key = key.clone();
@@ -369,6 +372,7 @@
group_key,
worker_registry,
policy_registry,
worker_load_manager,
client,
interval,
tx,
@@ -402,6 +406,8 @@ impl LoadMonitor {
map.remove(url);
}
});
// Also remove from worker load manager's DP cached loads
self.worker_load_manager.remove_workers(worker_urls);
}
}

@@ -433,6 +439,7 @@ impl LoadMonitor {
group_key: WorkerGroupKey,
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
worker_load_manager: Arc<WorkerLoadManager>,
client: reqwest::Client,
interval: Duration,
tx: watch::Sender<HashMap<String, WorkerLoadResponse>>,
@@ -443,7 +450,7 @@
interval_timer.tick().await;

let power_of_two_policies = policy_registry.get_all_power_of_two_policies();
if power_of_two_policies.is_empty() && policy_registry.get_dp_rank_policy().is_none() {
debug!("No PowerOfTwo or DP-rank policies found, skipping load fetch for group {group_key}");
continue;
}
@@ -486,9 +493,12 @@

// Collect successful loads
let mut group_loads: HashMap<String, WorkerLoadResponse> = HashMap::new();
let mut group_dp_loads: HashMap<String, HashMap<isize, isize>> = HashMap::new();
for (url, response) in results {
if let Some(load) = response {
group_loads.insert(url.clone(), load.clone());
let dp_rank_loads = load.dp_rank_loads();
group_dp_loads.insert(url, dp_rank_loads);
}
}

@@ -507,6 +517,7 @@
for policy in &power_of_two_policies {
policy.update_loads(&group_loads);
}
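// Refresh the per-DP-rank load cache alongside the PowerOfTwo load updates.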
worker_load_manager.update_dp_loads(&group_dp_loads);

// Atomically merge into the shared watch channel.
// Remove all group URLs first to clear stale entries from workers
5 changes: 5 additions & 0 deletions model_gateway/src/main.rs
@@ -200,6 +200,10 @@ struct CliArgs {
#[arg(long, default_value_t = false, help_heading = "Routing Policy")]
enable_igw: bool,

/// Enable the minimum tokens scheduler for the data parallel group
#[arg(long, default_value_t = false, help_heading = "Routing Policy")]
dp_minimum_tokens_scheduler: bool,

// ==================== PD Disaggregation ====================
/// Enable PD (Prefill-Decode) disaggregated mode
#[arg(long, default_value_t = false, help_heading = "PD Disaggregation")]
@@ -1227,6 +1231,7 @@ impl CliArgs {
.enable_wasm(self.enable_wasm)
.maybe_storage_hook_wasm_path(self.storage_hook_wasm_path.as_deref())
.igw(self.enable_igw)
.dp_minimum_tokens_scheduler(self.dp_minimum_tokens_scheduler)
.maybe_server_cert_and_key(self.tls_cert_path.as_ref(), self.tls_key_path.as_ref());

builder.build()